{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "<img src=\"../Pierian-Data-Logo.PNG\">\n",
    "<br>\n",
    "<strong><center>Copyright 2019. Created by Jose Marcial Portilla.</center></strong>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Full Artificial Neural Network Code Along - CLASSIFICATION\n",
    "In the last section we took in four continuous variables (lengths) to perform a classification. In this section we'll combine continuous and categorical data to perform a similar classification. The goal is to estimate the relative cost of a New York City cab ride from several inputs. The inspiration behind this code along is a recent <a href='https://www.kaggle.com/c/new-york-city-taxi-fare-prediction'>Kaggle competition</a>.\n",
    "\n",
    "<div class=\"alert alert-success\"><strong>NOTE:</strong> This notebook differs from the previous regression notebook in that it uses <tt>'fare_class'</tt> for the <tt><strong>y</strong></tt> set, and the output contains two values instead of one. In this exercise we're training our model to perform a binary classification, and predict whether a fare is greater or less than $10.00.</div>\n",
    "\n",
    "## Working with tabular data\n",
    "Deep learning with neural networks is often associated with sophisticated image recognition, and in upcoming sections we'll train models based on properties like pixels patterns and colors.\n",
    "\n",
    "Here we're working with tabular data (spreadsheets, SQL tables, etc.) with columns of values that may or may not be relevant. As it happens, neural networks can learn to make connections we probably wouldn't have developed on our own. However, to do this we have to handle categorical values separately from continuous ones. Make sure to watch the theory lectures! You'll want to be comfortable with:\n",
    "* continuous vs. categorical values\n",
    "* embeddings\n",
    "* batch normalization\n",
    "* dropout layers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Perform standard imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the NYC Taxi Fares dataset\n",
    "The <a href='https://www.kaggle.com/c/new-york-city-taxi-fare-prediction'>Kaggle competition</a> provides a dataset with about 55 million records. The data contains only the pickup date & time, the latitude & longitude (GPS coordinates) of the pickup and dropoff locations, and the number of passengers. It is up to the contest participant to extract any further information. For instance, does the time of day matter? The day of the week? How do we determine the distance traveled from pairs of GPS coordinates?\n",
    "\n",
    "For this exercise we've whittled the dataset down to just 120,000 records from April 11 to April 24, 2010. The records are randomly sorted. We'll show how to calculate distance from GPS coordinates, and how to create a pandas datatime object from a text column. This will let us quickly get information like day of the week, am vs. pm, etc.\n",
    "\n",
    "Let's get started!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>pickup_datetime</th>\n",
       "      <th>fare_amount</th>\n",
       "      <th>fare_class</th>\n",
       "      <th>pickup_longitude</th>\n",
       "      <th>pickup_latitude</th>\n",
       "      <th>dropoff_longitude</th>\n",
       "      <th>dropoff_latitude</th>\n",
       "      <th>passenger_count</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>2010-04-19 08:17:56 UTC</td>\n",
       "      <td>6.5</td>\n",
       "      <td>0</td>\n",
       "      <td>-73.992365</td>\n",
       "      <td>40.730521</td>\n",
       "      <td>-73.975499</td>\n",
       "      <td>40.744746</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2010-04-17 15:43:53 UTC</td>\n",
       "      <td>6.9</td>\n",
       "      <td>0</td>\n",
       "      <td>-73.990078</td>\n",
       "      <td>40.740558</td>\n",
       "      <td>-73.974232</td>\n",
       "      <td>40.744114</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2010-04-17 11:23:26 UTC</td>\n",
       "      <td>10.1</td>\n",
       "      <td>1</td>\n",
       "      <td>-73.994149</td>\n",
       "      <td>40.751118</td>\n",
       "      <td>-73.960064</td>\n",
       "      <td>40.766235</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>2010-04-11 21:25:03 UTC</td>\n",
       "      <td>8.9</td>\n",
       "      <td>0</td>\n",
       "      <td>-73.990485</td>\n",
       "      <td>40.756422</td>\n",
       "      <td>-73.971205</td>\n",
       "      <td>40.748192</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>2010-04-17 02:19:01 UTC</td>\n",
       "      <td>19.7</td>\n",
       "      <td>1</td>\n",
       "      <td>-73.990976</td>\n",
       "      <td>40.734202</td>\n",
       "      <td>-73.905956</td>\n",
       "      <td>40.743115</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "           pickup_datetime  fare_amount  fare_class  pickup_longitude  \\\n",
       "0  2010-04-19 08:17:56 UTC          6.5           0        -73.992365   \n",
       "1  2010-04-17 15:43:53 UTC          6.9           0        -73.990078   \n",
       "2  2010-04-17 11:23:26 UTC         10.1           1        -73.994149   \n",
       "3  2010-04-11 21:25:03 UTC          8.9           0        -73.990485   \n",
       "4  2010-04-17 02:19:01 UTC         19.7           1        -73.990976   \n",
       "\n",
       "   pickup_latitude  dropoff_longitude  dropoff_latitude  passenger_count  \n",
       "0        40.730521         -73.975499         40.744746                1  \n",
       "1        40.740558         -73.974232         40.744114                1  \n",
       "2        40.751118         -73.960064         40.766235                2  \n",
       "3        40.756422         -73.971205         40.748192                1  \n",
       "4        40.734202         -73.905956         40.743115                1  "
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = pd.read_csv('../Data/NYCTaxiFares.csv')\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0    80000\n",
       "1    40000\n",
       "Name: fare_class, dtype: int64"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df['fare_class'].value_counts()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Conveniently, 2/3 of the data have fares under \\\\$10, and 1/3 have fares \\\\$10 and above."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Fare classes correspond to fare amounts as follows:\n",
    "<table style=\"display: inline-block\">\n",
    "<tr><th>Class</th><th>Values</th></tr>\n",
    "<tr><td>0</td><td>< \\$10.00</td></tr>\n",
    "<tr><td>1</td><td>>= \\$10.00</td></tr>\n",
    "</table>"
   ]
  },
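  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check on the table above, here is a minimal sketch that rebuilds the class labels from <tt>fare_amount</tt> for the five rows shown by <tt>df.head()</tt>. It assumes <tt>fare_class</tt> follows exactly the 10.00 threshold in the table; on the full DataFrame the same idea would be <tt>(df['fare_amount'] >= 10).astype(int)</tt>.\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "# fare_amount values from the first five rows displayed by df.head()\n",
    "sample = pd.DataFrame({'fare_amount': [6.5, 6.9, 10.1, 8.9, 19.7]})\n",
    "\n",
    "# Class 1 if the fare is at least 10.00, otherwise class 0 (per the table above)\n",
    "sample['fare_class'] = (sample['fare_amount'] >= 10).astype(int)\n",
    "print(sample['fare_class'].tolist())  # [0, 0, 1, 0, 1]\n",
    "\n",
    "# Class proportions, using the counts reported by value_counts() above\n",
    "counts = pd.Series({0: 80000, 1: 40000})\n",
    "proportions = (counts / counts.sum()).round(3)\n",
    "print(proportions.tolist())  # [0.667, 0.333]\n",
    "```\n",
    "\n",
    "The rebuilt labels match the <tt>fare_class</tt> column for those rows, and the proportions confirm the roughly 2/3 vs. 1/3 split."
   ]
  },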
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Calculate the distance traveled\n",
    "The <a href='https://en.wikipedia.org/wiki/Haversine_formula'>haversine formula</a> calculates the distance on a sphere between two sets of GPS coordinates.<br>\n",
    "Here we assign latitude values with $\\varphi$ (phi) and longitude with $\\lambda$ (lambda).\n",
    "\n",
    "The distance formula works out to\n",
    "\n",
    "${\\displaystyle d=2r\\arcsin \\left({\\sqrt {\\sin ^{2}\\left({\\frac {\\varphi _{2}-\\varphi _{1}}{2}}\\right)+\\cos(\\varphi _{1})\\:\\cos(\\varphi _{2})\\:\\sin ^{2}\\left({\\frac {\\lambda _{2}-\\lambda _{1}}{2}}\\right)}}\\right)}$\n",
    "\n",
    "where\n",
    "\n",
    "$\\begin{split} r&: \\textrm {radius of the sphere (Earth's radius averages 6371 km)}\\\\\n",
    "\\varphi_1, \\varphi_2&: \\textrm {latitudes of point 1 and point 2}\\\\\n",
    "\\lambda_1, \\lambda_2&: \\textrm {longitudes of point 1 and point 2}\\end{split}$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def haversine_distance(df, lat1, long1, lat2, long2):\n",
    "    \"\"\"\n",
    "    Calculates the haversine distance between 2 sets of GPS coordinates in df\n",
    "    \"\"\"\n",
    "    r = 6371  # average radius of Earth in kilometers\n",
    "       \n",
    "    phi1 = np.radians(df[lat1])\n",
    "    phi2 = np.radians(df[lat2])\n",
    "    \n",
    "    delta_phi = np.radians(df[lat2]-df[lat1])\n",
    "    delta_lambda = np.radians(df[long2]-df[long1])\n",
    "     \n",
    "    a = np.sin(delta_phi/2)**2 + np.cos(phi1) * np.cos(phi2) * np.sin(delta_lambda/2)**2\n",
    "    c = 2 * np.arctan2(np.sqrt(a), np.sqrt(1-a))\n",
    "    d = (r * c) # in kilometers\n",
    "\n",
    "    return d"
   ]
  },
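  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The function above computes the central angle as <tt>2 * arctan2(sqrt(a), sqrt(1-a))</tt>, whereas the formula in the markdown cell uses <tt>2 * arcsin(sqrt(a))</tt>. For a between 0 and 1 these are mathematically identical. The sketch below is a scalar version written with Python's standard <tt>math</tt> module (the helper name <tt>haversine_km</tt> is hypothetical) that evaluates both forms for the first row of the data:\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def haversine_km(lat1, lon1, lat2, lon2, r=6371.0):\n",
    "    # Hypothetical scalar helper: returns the distance computed both ways\n",
    "    phi1, phi2 = math.radians(lat1), math.radians(lat2)\n",
    "    dphi = math.radians(lat2 - lat1)\n",
    "    dlam = math.radians(lon2 - lon1)\n",
    "    a = math.sin(dphi / 2) ** 2 + math.cos(phi1) * math.cos(phi2) * math.sin(dlam / 2) ** 2\n",
    "    d_asin = 2 * r * math.asin(math.sqrt(a))                      # form in the markdown formula\n",
    "    d_atan2 = 2 * r * math.atan2(math.sqrt(a), math.sqrt(1 - a))  # form used in haversine_distance\n",
    "    return d_asin, d_atan2\n",
    "\n",
    "# Row 0: pickup (40.730521, -73.992365), dropoff (40.744746, -73.975499)\n",
    "d_asin, d_atan2 = haversine_km(40.730521, -73.992365, 40.744746, -73.975499)\n",
    "print(round(d_asin, 6), round(d_atan2, 6))  # both approximately 2.126312, the dist_km shown for row 0\n",
    "```"
   ]
  },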
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>pickup_datetime</th>\n",
       "      <th>fare_amount</th>\n",
       "      <th>fare_class</th>\n",
       "      <th>pickup_longitude</th>\n",
       "      <th>pickup_latitude</th>\n",
       "      <th>dropoff_longitude</th>\n",
       "      <th>dropoff_latitude</th>\n",
       "      <th>passenger_count</th>\n",
       "      <th>dist_km</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>2010-04-19 08:17:56 UTC</td>\n",
       "      <td>6.5</td>\n",
       "      <td>0</td>\n",
       "      <td>-73.992365</td>\n",
       "      <td>40.730521</td>\n",
       "      <td>-73.975499</td>\n",
       "      <td>40.744746</td>\n",
       "      <td>1</td>\n",
       "      <td>2.126312</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2010-04-17 15:43:53 UTC</td>\n",
       "      <td>6.9</td>\n",
       "      <td>0</td>\n",
       "      <td>-73.990078</td>\n",
       "      <td>40.740558</td>\n",
       "      <td>-73.974232</td>\n",
       "      <td>40.744114</td>\n",
       "      <td>1</td>\n",
       "      <td>1.392307</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2010-04-17 11:23:26 UTC</td>\n",
       "      <td>10.1</td>\n",
       "      <td>1</td>\n",
       "      <td>-73.994149</td>\n",
       "      <td>40.751118</td>\n",
       "      <td>-73.960064</td>\n",
       "      <td>40.766235</td>\n",
       "      <td>2</td>\n",
       "      <td>3.326763</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>2010-04-11 21:25:03 UTC</td>\n",
       "      <td>8.9</td>\n",
       "      <td>0</td>\n",
       "      <td>-73.990485</td>\n",
       "      <td>40.756422</td>\n",
       "      <td>-73.971205</td>\n",
       "      <td>40.748192</td>\n",
       "      <td>1</td>\n",
       "      <td>1.864129</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>2010-04-17 02:19:01 UTC</td>\n",
       "      <td>19.7</td>\n",
       "      <td>1</td>\n",
       "      <td>-73.990976</td>\n",
       "      <td>40.734202</td>\n",
       "      <td>-73.905956</td>\n",
       "      <td>40.743115</td>\n",
       "      <td>1</td>\n",
       "      <td>7.231321</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "           pickup_datetime  fare_amount  fare_class  pickup_longitude  \\\n",
       "0  2010-04-19 08:17:56 UTC          6.5           0        -73.992365   \n",
       "1  2010-04-17 15:43:53 UTC          6.9           0        -73.990078   \n",
       "2  2010-04-17 11:23:26 UTC         10.1           1        -73.994149   \n",
       "3  2010-04-11 21:25:03 UTC          8.9           0        -73.990485   \n",
       "4  2010-04-17 02:19:01 UTC         19.7           1        -73.990976   \n",
       "\n",
       "   pickup_latitude  dropoff_longitude  dropoff_latitude  passenger_count  \\\n",
       "0        40.730521         -73.975499         40.744746                1   \n",
       "1        40.740558         -73.974232         40.744114                1   \n",
       "2        40.751118         -73.960064         40.766235                2   \n",
       "3        40.756422         -73.971205         40.748192                1   \n",
       "4        40.734202         -73.905956         40.743115                1   \n",
       "\n",
       "    dist_km  \n",
       "0  2.126312  \n",
       "1  1.392307  \n",
       "2  3.326763  \n",
       "3  1.864129  \n",
       "4  7.231321  "
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df['dist_km'] = haversine_distance(df,'pickup_latitude', 'pickup_longitude', 'dropoff_latitude', 'dropoff_longitude')\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Add a datetime column and derive useful statistics\n",
    "By creating a datetime object, we can extract information like \"day of the week\", \"am vs. pm\" etc.\n",
    "Note that the data was saved in UTC time. Our data falls in April of 2010 which occurred during Daylight Savings Time in New York. For that reason, we'll make an adjustment to EDT using UTC-4 (subtracting four hours)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>pickup_datetime</th>\n",
       "      <th>fare_amount</th>\n",
       "      <th>fare_class</th>\n",
       "      <th>pickup_longitude</th>\n",
       "      <th>pickup_latitude</th>\n",
       "      <th>dropoff_longitude</th>\n",
       "      <th>dropoff_latitude</th>\n",
       "      <th>passenger_count</th>\n",
       "      <th>dist_km</th>\n",
       "      <th>EDTdate</th>\n",
       "      <th>Hour</th>\n",
       "      <th>AMorPM</th>\n",
       "      <th>Weekday</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>2010-04-19 08:17:56 UTC</td>\n",
       "      <td>6.5</td>\n",
       "      <td>0</td>\n",
       "      <td>-73.992365</td>\n",
       "      <td>40.730521</td>\n",
       "      <td>-73.975499</td>\n",
       "      <td>40.744746</td>\n",
       "      <td>1</td>\n",
       "      <td>2.126312</td>\n",
       "      <td>2010-04-19 04:17:56</td>\n",
       "      <td>4</td>\n",
       "      <td>am</td>\n",
       "      <td>Mon</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2010-04-17 15:43:53 UTC</td>\n",
       "      <td>6.9</td>\n",
       "      <td>0</td>\n",
       "      <td>-73.990078</td>\n",
       "      <td>40.740558</td>\n",
       "      <td>-73.974232</td>\n",
       "      <td>40.744114</td>\n",
       "      <td>1</td>\n",
       "      <td>1.392307</td>\n",
       "      <td>2010-04-17 11:43:53</td>\n",
       "      <td>11</td>\n",
       "      <td>am</td>\n",
       "      <td>Sat</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2010-04-17 11:23:26 UTC</td>\n",
       "      <td>10.1</td>\n",
       "      <td>1</td>\n",
       "      <td>-73.994149</td>\n",
       "      <td>40.751118</td>\n",
       "      <td>-73.960064</td>\n",
       "      <td>40.766235</td>\n",
       "      <td>2</td>\n",
       "      <td>3.326763</td>\n",
       "      <td>2010-04-17 07:23:26</td>\n",
       "      <td>7</td>\n",
       "      <td>am</td>\n",
       "      <td>Sat</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>2010-04-11 21:25:03 UTC</td>\n",
       "      <td>8.9</td>\n",
       "      <td>0</td>\n",
       "      <td>-73.990485</td>\n",
       "      <td>40.756422</td>\n",
       "      <td>-73.971205</td>\n",
       "      <td>40.748192</td>\n",
       "      <td>1</td>\n",
       "      <td>1.864129</td>\n",
       "      <td>2010-04-11 17:25:03</td>\n",
       "      <td>17</td>\n",
       "      <td>pm</td>\n",
       "      <td>Sun</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>2010-04-17 02:19:01 UTC</td>\n",
       "      <td>19.7</td>\n",
       "      <td>1</td>\n",
       "      <td>-73.990976</td>\n",
       "      <td>40.734202</td>\n",
       "      <td>-73.905956</td>\n",
       "      <td>40.743115</td>\n",
       "      <td>1</td>\n",
       "      <td>7.231321</td>\n",
       "      <td>2010-04-16 22:19:01</td>\n",
       "      <td>22</td>\n",
       "      <td>pm</td>\n",
       "      <td>Fri</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "           pickup_datetime  fare_amount  fare_class  pickup_longitude  \\\n",
       "0  2010-04-19 08:17:56 UTC          6.5           0        -73.992365   \n",
       "1  2010-04-17 15:43:53 UTC          6.9           0        -73.990078   \n",
       "2  2010-04-17 11:23:26 UTC         10.1           1        -73.994149   \n",
       "3  2010-04-11 21:25:03 UTC          8.9           0        -73.990485   \n",
       "4  2010-04-17 02:19:01 UTC         19.7           1        -73.990976   \n",
       "\n",
       "   pickup_latitude  dropoff_longitude  dropoff_latitude  passenger_count  \\\n",
       "0        40.730521         -73.975499         40.744746                1   \n",
       "1        40.740558         -73.974232         40.744114                1   \n",
       "2        40.751118         -73.960064         40.766235                2   \n",
       "3        40.756422         -73.971205         40.748192                1   \n",
       "4        40.734202         -73.905956         40.743115                1   \n",
       "\n",
       "    dist_km             EDTdate  Hour AMorPM Weekday  \n",
       "0  2.126312 2010-04-19 04:17:56     4     am     Mon  \n",
       "1  1.392307 2010-04-17 11:43:53    11     am     Sat  \n",
       "2  3.326763 2010-04-17 07:23:26     7     am     Sat  \n",
       "3  1.864129 2010-04-11 17:25:03    17     pm     Sun  \n",
       "4  7.231321 2010-04-16 22:19:01    22     pm     Fri  "
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df['EDTdate'] = pd.to_datetime(df['pickup_datetime'].str[:19]) - pd.Timedelta(hours=4)\n",
    "df['Hour'] = df['EDTdate'].dt.hour\n",
    "df['AMorPM'] = np.where(df['Hour']<12,'am','pm')\n",
    "df['Weekday'] = df['EDTdate'].dt.strftime(\"%a\")\n",
    "df.head()"
   ]
  },
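  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Subtracting a fixed four hours is correct here only because every record falls inside Daylight Saving Time. If the data crossed a DST transition, a timezone-aware conversion would be safer. A minimal sketch using pandas' <tt>tz_localize</tt> and <tt>tz_convert</tt>, assuming the <tt>'America/New_York'</tt> zone is available in your timezone database:\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "# Two raw strings in the same format as the pickup_datetime column\n",
    "raw = pd.Series(['2010-04-19 08:17:56 UTC', '2010-04-17 02:19:01 UTC'])\n",
    "\n",
    "# Approach used above: drop the ' UTC' suffix and subtract a fixed 4 hours\n",
    "fixed = pd.to_datetime(raw.str[:19]) - pd.Timedelta(hours=4)\n",
    "\n",
    "# Alternative: mark the times as UTC, convert to New York local time\n",
    "# (DST rules are applied automatically), then drop the timezone info\n",
    "local = (pd.to_datetime(raw.str[:19])\n",
    "         .dt.tz_localize('UTC')\n",
    "         .dt.tz_convert('America/New_York')\n",
    "         .dt.tz_localize(None))\n",
    "\n",
    "print(local.tolist())\n",
    "print((fixed == local).all())  # True, since both dates fall inside DST\n",
    "```\n",
    "\n",
    "Outside DST (roughly early November to mid-March in recent years) New York is on UTC-5, so a fixed four-hour subtraction would be off by an hour for those dates."
   ]
  },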
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Timestamp('2010-04-11 00:00:10')"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df['EDTdate'].min()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Timestamp('2010-04-24 23:59:42')"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df['EDTdate'].max()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Separate categorical from continuous columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Index(['pickup_datetime', 'fare_amount', 'fare_class', 'pickup_longitude',\n",
       "       'pickup_latitude', 'dropoff_longitude', 'dropoff_latitude',\n",
       "       'passenger_count', 'dist_km', 'EDTdate', 'Hour', 'AMorPM', 'Weekday'],\n",
       "      dtype='object')"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "cat_cols = ['Hour', 'AMorPM', 'Weekday']\n",
    "cont_cols = ['pickup_latitude', 'pickup_longitude', 'dropoff_latitude', 'dropoff_longitude', 'passenger_count', 'dist_km']\n",
    "y_col = ['fare_class']  # this column contains the labels"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-info\"><strong>NOTE:</strong> If you plan to use all of the columns in the data table, there's a shortcut to grab the remaining continuous columns:<br>\n",
    "<pre style='background-color:rgb(217,237,247)'>cont_cols = [col for col in df.columns if col not in cat_cols + y_col]</pre>\n",
    "\n",
    "Here we entered the continuous columns explicitly because there are columns we're not running through the model (fare_amount and EDTdate)</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Categorify\n",
    "Pandas offers a <a href='https://pandas.pydata.org/pandas-docs/stable/user_guide/categorical.html'><strong>category dtype</strong></a> for converting categorical values to numerical codes. A dataset containing months of the year will be assigned 12 codes, one for each month. These will usually be the integers 0 to 11. Pandas replaces the column values with codes, and retains an index list of category values. In the steps ahead we'll call the categorical values \"names\" and the encodings \"codes\"."
   ]
  },
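  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One detail worth knowing: unless you specify the categories yourself, pandas sorts them (alphabetically for strings), so month names would not receive codes in calendar order. A small sketch with three month abbreviations:\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "months = pd.Series(['Mar', 'Jan', 'Feb', 'Jan'])\n",
    "\n",
    "# Default: categories are sorted lexically\n",
    "auto = months.astype('category')\n",
    "print(list(auto.cat.categories), auto.cat.codes.tolist())  # ['Feb', 'Jan', 'Mar'] [2, 1, 0, 1]\n",
    "\n",
    "# Explicit category order: codes now follow the calendar\n",
    "cal = pd.Categorical(months, categories=['Jan', 'Feb', 'Mar'])\n",
    "print(list(cal.categories), cal.codes.tolist())  # ['Jan', 'Feb', 'Mar'] [2, 0, 1, 0]\n",
    "```\n",
    "\n",
    "The same alphabetical sorting is why <tt>'Fri'</tt> ends up as code 0 for <tt>df['Weekday']</tt>. For embeddings the order itself doesn't matter, as long as the same name-to-code mapping is used for training and prediction."
   ]
  },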
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert our three categorical columns to category dtypes.\n",
    "for cat in cat_cols:\n",
    "    df[cat] = df[cat].astype('category')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "pickup_datetime              object\n",
       "fare_amount                 float64\n",
       "fare_class                    int64\n",
       "pickup_longitude            float64\n",
       "pickup_latitude             float64\n",
       "dropoff_longitude           float64\n",
       "dropoff_latitude            float64\n",
       "passenger_count               int64\n",
       "dist_km                     float64\n",
       "EDTdate              datetime64[ns]\n",
       "Hour                       category\n",
       "AMorPM                     category\n",
       "Weekday                    category\n",
       "dtype: object"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.dtypes"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see that <tt>df['Hour']</tt> is a categorical feature by displaying some of the rows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0     4\n",
       "1    11\n",
       "2     7\n",
       "3    17\n",
       "4    22\n",
       "Name: Hour, dtype: category\n",
       "Categories (24, int64): [0, 1, 2, 3, ..., 20, 21, 22, 23]"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df['Hour'].head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here our categorical names are the integers 0 through 23, for a total of 24 unique categories. These values <em>also</em> correspond to the codes assigned to each name.\n",
    "\n",
    "We can access the category names with <tt>Series.cat.categories</tt> or just the codes with <tt>Series.cat.codes</tt>. This will make more sense if we look at <tt>df['AMorPM']</tt>:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0    am\n",
       "1    am\n",
       "2    am\n",
       "3    pm\n",
       "4    pm\n",
       "Name: AMorPM, dtype: category\n",
       "Categories (2, object): [am, pm]"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df['AMorPM'].head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Index(['am', 'pm'], dtype='object')"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df['AMorPM'].cat.categories"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0    0\n",
       "1    0\n",
       "2    0\n",
       "3    1\n",
       "4    1\n",
       "dtype: int8"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df['AMorPM'].head().cat.codes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Index(['Fri', 'Mon', 'Sat', 'Sun', 'Thu', 'Tue', 'Wed'], dtype='object')"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df['Weekday'].cat.categories"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0    1\n",
       "1    2\n",
       "2    2\n",
       "3    3\n",
       "4    0\n",
       "dtype: int8"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df['Weekday'].head().cat.codes"
   ]
  },
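  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Going the other way, codes can be mapped back to names by indexing into the categories. A small sketch using the five <tt>Weekday</tt> values from <tt>df.head()</tt>; only four weekdays appear in this toy Series, but their codes match the full dataset because <tt>'Thu'</tt>, <tt>'Tue'</tt> and <tt>'Wed'</tt> all sort after <tt>'Sun'</tt>:\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "wk = pd.Series(['Mon', 'Sat', 'Sat', 'Sun', 'Fri'], dtype='category')\n",
    "codes = wk.cat.codes.to_numpy()\n",
    "print(list(wk.cat.categories), codes.tolist())  # ['Fri', 'Mon', 'Sat', 'Sun'] [1, 2, 2, 3, 0]\n",
    "\n",
    "# Decode by indexing the categories with the codes...\n",
    "names = wk.cat.categories[codes]\n",
    "# ...or rebuild a Categorical directly from codes and categories\n",
    "rebuilt = pd.Categorical.from_codes(codes, categories=wk.cat.categories)\n",
    "\n",
    "print(list(names))    # ['Mon', 'Sat', 'Sat', 'Sun', 'Fri']\n",
    "print(list(rebuilt))  # ['Mon', 'Sat', 'Sat', 'Sun', 'Fri']\n",
    "```"
   ]
  },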
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-info\"><strong>NOTE: </strong>NaN values in categorical data are assigned a code of -1. We don't have any in this particular dataset.</div>"
   ]
  },
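  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If a column did contain missing values, the -1 code would need handling before it reaches an embedding layer, since an embedding lookup cannot use a negative index. One common option, sketched here as an assumption rather than part of this notebook, is to add an explicit <tt>'missing'</tt> category:\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "s = pd.Series(['am', None, 'pm'], dtype='category')\n",
    "print(s.cat.codes.tolist())  # [0, -1, 1]\n",
    "\n",
    "# Add a dedicated category for missing values, then fill the gaps with it\n",
    "filled = s.cat.add_categories('missing').fillna('missing')\n",
    "print(list(filled.cat.categories), filled.cat.codes.tolist())  # ['am', 'pm', 'missing'] [0, 2, 1]\n",
    "```"
   ]
  },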
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we want to combine the three categorical columns into one input array using <a href='https://docs.scipy.org/doc/numpy/reference/generated/numpy.stack.html'><tt>numpy.stack</tt></a> We don't want the Series index, just the values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 4,  0,  1],\n",
       "       [11,  0,  2],\n",
       "       [ 7,  0,  2],\n",
       "       [17,  1,  3],\n",
       "       [22,  1,  0]], dtype=int8)"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "hr = df['Hour'].cat.codes.values\n",
    "ampm = df['AMorPM'].cat.codes.values\n",
    "wkdy = df['Weekday'].cat.codes.values\n",
    "\n",
    "cats = np.stack([hr, ampm, wkdy], 1)\n",
    "\n",
    "cats[:5]"
   ]
  },
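  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The second argument to <tt>np.stack</tt> is the axis along which the arrays are joined. With <tt>axis=1</tt> each 1-D array becomes a column, giving one row per record; <tt>axis=0</tt> would instead give one row per feature. A toy sketch using the first five codes from above:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Three 1-D code arrays of length 5, like hr, ampm and wkdy above\n",
    "hr = np.array([4, 11, 7, 17, 22], dtype=np.int8)\n",
    "ampm = np.array([0, 0, 0, 1, 1], dtype=np.int8)\n",
    "wkdy = np.array([1, 2, 2, 3, 0], dtype=np.int8)\n",
    "\n",
    "by_record = np.stack([hr, ampm, wkdy], axis=1)   # shape (5, 3): one row per record\n",
    "by_feature = np.stack([hr, ampm, wkdy], axis=0)  # shape (3, 5): one row per feature\n",
    "\n",
    "print(by_record.shape, by_feature.shape)        # (5, 3) (3, 5)\n",
    "print(np.array_equal(by_record, by_feature.T))  # True\n",
    "```"
   ]
  },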
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-info\"><strong>NOTE:</strong> This can be done in one line of code using a list comprehension:\n",
    "<pre style='background-color:rgb(217,237,247)'>cats = np.stack([df[col].cat.codes.values for col in cat_cols], 1)</pre>\n",
    "\n",
    "Don't worry about the dtype for now, we can make it int64 when we convert it to a tensor.</div>\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Convert numpy arrays to tensors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 4,  0,  1],\n",
       "        [11,  0,  2],\n",
       "        [ 7,  0,  2],\n",
       "        [17,  1,  3],\n",
       "        [22,  1,  0]])"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Convert categorical variables to a tensor\n",
    "cats = torch.tensor(cats, dtype=torch.int64)\n",
    "# this syntax is ok, since the source data is an array, not an existing tensor\n",
    "\n",
    "cats[:5]"
   ]
  },
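  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Converting to <tt>int64</tt> matters because these codes will serve as lookup indices into embedding tables, and <tt>nn.Embedding</tt> requires integer (int32 or int64) indices. A minimal sketch of that lookup on the five rows shown above; the per-column category counts come from this notebook (24 hours, 2 am/pm values, 7 weekdays), while the embedding widths use the rule of thumb <tt>min(50, (n+1)//2)</tt>, which is an assumption here rather than something this section defines:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "# First five rows of the categorical tensor above: (Hour, AMorPM, Weekday) codes\n",
    "cats = torch.tensor([[4, 0, 1], [11, 0, 2], [7, 0, 2], [17, 1, 3], [22, 1, 0]], dtype=torch.int64)\n",
    "\n",
    "cat_szs = [24, 2, 7]  # number of categories in each column\n",
    "# Assumed embedding widths (rule of thumb, not taken from this section)\n",
    "emb_szs = [(n, min(50, (n + 1) // 2)) for n in cat_szs]  # [(24, 12), (2, 1), (7, 4)]\n",
    "embeds = nn.ModuleList([nn.Embedding(n, d) for n, d in emb_szs])\n",
    "\n",
    "# Look up each column in its own table, then concatenate along the feature axis\n",
    "x = torch.cat([emb(cats[:, i]) for i, emb in enumerate(embeds)], dim=1)\n",
    "print(x.shape)  # torch.Size([5, 17])\n",
    "```\n",
    "\n",
    "Each row becomes a 17-dimensional vector of learned values (12 + 1 + 4), which a model can then combine with the continuous features."
   ]
  },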
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can feed all of our continuous variables into the model as a tensor. We're not normalizing the values here; we'll let the model perform this step.\n",
    "<div class=\"alert alert-info\"><strong>NOTE:</strong> We have to store <tt>conts</tt> and <tt>y</tt> as Float (float32) tensors, not Double (float64) in order for batch normalization to work properly.</div>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 40.7305, -73.9924,  40.7447, -73.9755,   1.0000,   2.1263],\n",
       "        [ 40.7406, -73.9901,  40.7441, -73.9742,   1.0000,   1.3923],\n",
       "        [ 40.7511, -73.9941,  40.7662, -73.9601,   2.0000,   3.3268],\n",
       "        [ 40.7564, -73.9905,  40.7482, -73.9712,   1.0000,   1.8641],\n",
       "        [ 40.7342, -73.9910,  40.7431, -73.9060,   1.0000,   7.2313]])"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Convert continuous variables to a tensor\n",
    "conts = np.stack([df[col].values for col in cont_cols], 1)\n",
    "conts = torch.tensor(conts, dtype=torch.float)\n",
    "conts[:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'torch.FloatTensor'"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "conts.type()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note: the CrossEntropyLoss function we'll use below expects a 1-D tensor of class labels, so instead of <tt>.reshape(-1,1)</tt> (used in the regression notebook) we'll use <tt>.flatten()</tt> this time."
   ]
  },
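  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see the shape requirement in isolation, here is a small sketch with made-up scores: <tt>nn.CrossEntropyLoss</tt> takes raw scores of shape <tt>(N, C)</tt> and integer class labels of shape <tt>(N,)</tt>, so a column vector of shape <tt>(N, 1)</tt> has to be flattened first. The names <tt>as_column</tt>, <tt>as_vector</tt> and <tt>loss_demo</tt> are illustrative only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "labels = torch.tensor([0, 0, 1, 0, 1])\n",
    "as_column = labels.reshape(-1, 1)  # shape (5, 1): the layout used for regression targets\n",
    "as_vector = as_column.flatten()    # shape (5,): what CrossEntropyLoss expects\n",
    "\n",
    "# Made-up raw scores for 5 samples and 2 classes\n",
    "logits = torch.tensor([[1.0, -0.5], [0.2, 0.1], [-0.3, 0.9], [2.0, 0.0], [0.0, 1.5]])\n",
    "loss_demo = nn.CrossEntropyLoss()(logits, as_vector)\n",
    "print(as_column.shape, as_vector.shape, loss_demo.item())"
   ]
  },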
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0, 0, 1, 0, 1])"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Convert labels to a tensor\n",
    "y = torch.tensor(df[y_col].values).flatten()\n",
    "\n",
    "y[:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([120000, 3])"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cats.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([120000, 6])"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "conts.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([120000])"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set an embedding size\n",
    "The rule of thumb for determining the embedding size is to divide the number of unique entries in each column by 2 (rounding up), but not to exceed 50."
   ]
  },
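  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The rule can be written as a small function. The category counts below are the three from this notebook plus a hypothetical column with 150 categories, which shows the cap at 50 taking effect. Note that <tt>(size+1)//2</tt> is integer division that rounds up, so 7 categories get 4 dimensions. The helper name <tt>embedding_size</tt> is made up for this sketch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def embedding_size(n_categories, cap=50):\n",
    "    # Half the number of categories, rounded up, but never more than `cap`\n",
    "    return min(cap, (n_categories + 1) // 2)\n",
    "\n",
    "cat_szs_demo = [24, 2, 7, 150]  # 150 is a hypothetical high-cardinality column\n",
    "print([(n, embedding_size(n)) for n in cat_szs_demo])"
   ]
  },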
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[(24, 12), (2, 1), (7, 4)]"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# This will set embedding sizes for Hours, AMvsPM and Weekdays\n",
    "cat_szs = [len(df[col].cat.categories) for col in cat_cols]\n",
    "emb_szs = [(size, min(50, (size+1)//2)) for size in cat_szs]\n",
    "emb_szs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define a TabularModel\n",
    "This loosely follows the <a href='https://docs.fast.ai/tabular.models.html'>fast.ai library</a>. The goal is to define a model based on the number of continuous columns (given by <tt>conts.shape[1]</tt>) plus the number of categorical columns and their embeddings (given by <tt>len(emb_szs)</tt> and <tt>emb_szs</tt> respectively). The output can be either a regression (a single float value) or a classification (one score per class, which softmax turns into probabilities). For this exercise the output will be two values, one for each fare class. Note that we'll assume our data contains both categorical and continuous data. You can add boolean parameters to your own model class to handle a variety of datasets."
   ]
  },
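  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The width of the first linear layer follows directly from these inputs: sum the embedding dimensions and add the number of continuous columns. A small sketch with this notebook's values (<tt>n_cont=6</tt>, from <tt>conts.shape[1]</tt>; the <tt>_demo</tt> names are illustrative) gives 17 and 23, matching <tt>Linear(in_features=23, ...)</tt> in the model summary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "emb_szs_demo = [(24, 12), (2, 1), (7, 4)]\n",
    "n_cont_demo = 6\n",
    "\n",
    "n_emb = sum(nf for ni, nf in emb_szs_demo)  # 12 + 1 + 4 embedding dimensions\n",
    "n_in = n_emb + n_cont_demo                  # embeddings plus continuous features\n",
    "print(n_emb, n_in)"
   ]
  },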
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-info\"><strong>Let's walk through the steps we're about to take. See below for more detailed illustrations of the steps.</strong><br>\n",
    "\n",
    "1. Extend the base Module class and set up the following parameters:\n",
    "   * <tt>emb_szs: </tt>list of tuples: each categorical variable size is paired with an embedding size\n",
    "   * <tt>n_cont:  </tt>int: number of continuous variables\n",
    "   * <tt>out_sz:  </tt>int: output size\n",
    "   * <tt>layers:  </tt>list of ints: layer sizes\n",
    "   * <tt>p:       </tt>float: dropout probability for each layer (for simplicity we'll use the same value throughout)\n",
    "   \n",
    "<tt><font color=black>class TabularModel(nn.Module):<br>\n",
    "&nbsp;&nbsp;&nbsp;&nbsp;def \\_\\_init\\_\\_(self, emb_szs, n_cont, out_sz, layers, p=0.5):<br>\n",
    "&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;super().\\_\\_init\\_\\_()</font></tt><br>\n",
    "\n",
    "2. Set up the embedding layers with <a href='https://pytorch.org/docs/stable/nn.html#modulelist'><tt><strong>torch.nn.ModuleList()</strong></tt></a> and <a href='https://pytorch.org/docs/stable/nn.html#embedding'><tt><strong>torch.nn.Embedding()</strong></tt></a><br>Categorical data will be passed through these embeddings in the forward method.<br>\n",
    "<tt><font color=black>&nbsp;&nbsp;&nbsp;&nbsp;self.embeds = nn.ModuleList([nn.Embedding(ni, nf) for ni,nf in emb_szs])</font></tt><br><br>\n",
    "3. Set up a dropout layer for the embeddings with <a href='https://pytorch.org/docs/stable/nn.html#dropout'><tt><strong>torch.nn.Dropout()</strong></tt></a>. The default dropout probability is p=0.5.<br>\n",
    "<tt><font color=black>&nbsp;&nbsp;&nbsp;&nbsp;self.emb_drop = nn.Dropout(p)</font></tt><br><br>\n",
    "4. Set up a normalization function for the continuous variables with <a href='https://pytorch.org/docs/stable/nn.html#batchnorm1d'><tt><strong>torch.nn.BatchNorm1d()</strong></tt></a><br>\n",
    "<tt><font color=black>&nbsp;&nbsp;&nbsp;&nbsp;self.bn_cont = nn.BatchNorm1d(n_cont)</font></tt><br><br>\n",
    "5. Set up a sequence of neural network layers where each level includes a linear layer, an activation function (we'll use <a href='https://pytorch.org/docs/stable/nn.html#relu'><strong>ReLU</strong></a>), a batch normalization step, and a dropout layer. We'll combine the list of layers with <a href='https://pytorch.org/docs/stable/nn.html#sequential'><tt><strong>torch.nn.Sequential()</strong></tt></a><br>\n",
    "<tt><font color=black>&nbsp;&nbsp;&nbsp;&nbsp;layerlist = []<br>\n",
    "&nbsp;&nbsp;&nbsp;&nbsp;n_emb = sum((nf for ni,nf in emb_szs))<br>\n",
    "&nbsp;&nbsp;&nbsp;&nbsp;n_in = n_emb + n_cont<br>\n",
    "<br>\n",
    "&nbsp;&nbsp;&nbsp;&nbsp;for i in layers:<br>\n",
    "&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;layerlist.append(nn.Linear(n_in,i)) <br>\n",
    "&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;layerlist.append(nn.ReLU(inplace=True))<br>\n",
    "&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;layerlist.append(nn.BatchNorm1d(i))<br>\n",
    "&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;layerlist.append(nn.Dropout(p))<br>\n",
    "&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;n_in = i<br>\n",
    "&nbsp;&nbsp;&nbsp;&nbsp;layerlist.append(nn.Linear(layers[-1],out_sz))<br>\n",
    "<br>\n",
    "&nbsp;&nbsp;&nbsp;&nbsp;self.layers = nn.Sequential(*layerlist)</font></tt><br><br>\n",
    "6. Define the forward method. Preprocess the embeddings and normalize the continuous variables before passing them through the layers.<br>Use <a href='https://pytorch.org/docs/stable/torch.html#torch.cat'><tt><strong>torch.cat()</strong></tt></a> to combine multiple tensors into one.<br>\n",
    "<tt><font color=black>def forward(self, x_cat, x_cont):<br>\n",
    "&nbsp;&nbsp;&nbsp;&nbsp;embeddings = []<br>\n",
    "&nbsp;&nbsp;&nbsp;&nbsp;for i,e in enumerate(self.embeds):<br>\n",
    "&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;embeddings.append(e(x_cat[:,i]))<br>\n",
    "&nbsp;&nbsp;&nbsp;&nbsp;x = torch.cat(embeddings, 1)<br>\n",
    "&nbsp;&nbsp;&nbsp;&nbsp;x = self.emb_drop(x)<br>\n",
    "<br>\n",
    "&nbsp;&nbsp;&nbsp;&nbsp;x_cont = self.bn_cont(x_cont)<br>\n",
    "&nbsp;&nbsp;&nbsp;&nbsp;x = torch.cat([x, x_cont], 1)<br>\n",
    "&nbsp;&nbsp;&nbsp;&nbsp;x = self.layers(x)<br>\n",
    "&nbsp;&nbsp;&nbsp;&nbsp;return x</font></tt>\n",
    "</div>"
   ]
  },
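  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Step 5 can be checked on its own before writing the full class. This sketch rebuilds just the layer stack with this notebook's sizes (23 inputs, hidden layers of 200 and 100, two outputs, <tt>p=0.4</tt>) and passes a random batch of 8 rows through it; the printed shape should be <tt>torch.Size([8, 2])</tt>. The variable names are illustrative and are not used by the model below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "n_in_demo, out_sz_demo, layers_demo, p_demo = 23, 2, [200, 100], 0.4\n",
    "\n",
    "layerlist_demo = []\n",
    "for width in layers_demo:\n",
    "    layerlist_demo.append(nn.Linear(n_in_demo, width))\n",
    "    layerlist_demo.append(nn.ReLU(inplace=True))\n",
    "    layerlist_demo.append(nn.BatchNorm1d(width))\n",
    "    layerlist_demo.append(nn.Dropout(p_demo))\n",
    "    n_in_demo = width  # the next layer takes this layer's output width\n",
    "layerlist_demo.append(nn.Linear(layers_demo[-1], out_sz_demo))\n",
    "stack = nn.Sequential(*layerlist_demo)\n",
    "\n",
    "x_demo = torch.randn(8, 23)  # random stand-in for concatenated embeddings + continuous features\n",
    "print(stack(x_demo).shape)"
   ]
  },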
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-danger\"><strong>Breaking down the embeddings steps</strong> (this code is for illustration purposes only.)</div>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 4,  0,  1],\n",
       "        [11,  0,  2],\n",
       "        [ 7,  0,  2],\n",
       "        [17,  1,  3]])"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# This is our source data\n",
    "catz = cats[:4]\n",
    "catz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[(24, 12), (2, 1), (7, 4)]"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# This is passed in when the model is instantiated\n",
    "emb_szs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ModuleList(\n",
       "  (0): Embedding(24, 12)\n",
       "  (1): Embedding(2, 1)\n",
       "  (2): Embedding(7, 4)\n",
       ")"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# This is assigned inside the __init__() method\n",
    "selfembeds = nn.ModuleList([nn.Embedding(ni, nf) for ni,nf in emb_szs])\n",
    "selfembeds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[(0, Embedding(24, 12)), (1, Embedding(2, 1)), (2, Embedding(7, 4))]"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "list(enumerate(selfembeds))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[tensor([[ 0.0347,  0.3536, -1.2988,  1.6375, -0.0542, -0.2099,  0.3044, -1.2855,\n",
       "           0.8831, -0.7109, -0.9646, -0.1356],\n",
       "         [-0.5039, -0.9924,  1.2296, -0.6908,  0.4641, -1.0487,  0.5577, -1.1560,\n",
       "           0.8318, -0.0834,  1.2123, -0.6210],\n",
       "         [ 0.3509,  0.2216,  0.3432,  1.4547, -0.8747,  1.6727, -0.6417, -1.0160,\n",
       "           0.8217, -1.0531,  0.8357, -0.0637],\n",
       "         [ 0.7978,  0.4566,  1.0926, -0.4095, -0.3366,  1.0216,  0.3601, -0.2927,\n",
       "           0.3536,  0.2170, -1.4778, -1.1965]], grad_fn=<EmbeddingBackward>),\n",
       " tensor([[-0.9676],\n",
       "         [-0.9676],\n",
       "         [-0.9676],\n",
       "         [-1.0656]], grad_fn=<EmbeddingBackward>),\n",
       " tensor([[-2.1762,  1.0210,  1.3557, -0.1804],\n",
       "         [-1.0131,  0.9989, -0.4746, -0.1461],\n",
       "         [-1.0131,  0.9989, -0.4746, -0.1461],\n",
       "         [-0.3646, -3.2237, -0.9956,  0.2598]], grad_fn=<EmbeddingBackward>)]"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# This happens inside the forward() method\n",
    "embeddingz = []\n",
    "for i,e in enumerate(selfembeds):\n",
    "    embeddingz.append(e(catz[:,i]))\n",
    "embeddingz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.0347,  0.3536, -1.2988,  1.6375, -0.0542, -0.2099,  0.3044, -1.2855,\n",
       "          0.8831, -0.7109, -0.9646, -0.1356, -0.9676, -2.1762,  1.0210,  1.3557,\n",
       "         -0.1804],\n",
       "        [-0.5039, -0.9924,  1.2296, -0.6908,  0.4641, -1.0487,  0.5577, -1.1560,\n",
       "          0.8318, -0.0834,  1.2123, -0.6210, -0.9676, -1.0131,  0.9989, -0.4746,\n",
       "         -0.1461],\n",
       "        [ 0.3509,  0.2216,  0.3432,  1.4547, -0.8747,  1.6727, -0.6417, -1.0160,\n",
       "          0.8217, -1.0531,  0.8357, -0.0637, -0.9676, -1.0131,  0.9989, -0.4746,\n",
       "         -0.1461],\n",
       "        [ 0.7978,  0.4566,  1.0926, -0.4095, -0.3366,  1.0216,  0.3601, -0.2927,\n",
       "          0.3536,  0.2170, -1.4778, -1.1965, -1.0656, -0.3646, -3.2237, -0.9956,\n",
       "          0.2598]], grad_fn=<CatBackward>)"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# We concatenate the embedding sections (12,1,4) into one (17)\n",
    "z = torch.cat(embeddingz, 1)\n",
    "z"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "# This was assigned inside the __init__() method\n",
    "selfembdrop = nn.Dropout(.4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.0000,  0.0000, -2.1647,  0.0000, -0.0000, -0.3498,  0.5073, -2.1424,\n",
       "          0.0000, -1.1848, -1.6076, -0.2259, -1.6127, -3.6271,  0.0000,  2.2594,\n",
       "         -0.3007],\n",
       "        [-0.8398, -0.0000,  0.0000, -0.0000,  0.7734, -1.7478,  0.0000, -1.9267,\n",
       "          0.0000, -0.1390,  0.0000, -1.0350, -0.0000, -0.0000,  1.6648, -0.0000,\n",
       "         -0.2435],\n",
       "        [ 0.0000,  0.3693,  0.5719,  0.0000, -1.4578,  0.0000, -1.0694, -1.6933,\n",
       "          0.0000, -1.7552,  1.3929, -0.1062, -1.6127, -1.6886,  1.6648, -0.0000,\n",
       "         -0.0000],\n",
       "        [ 1.3297,  0.0000,  0.0000, -0.0000, -0.0000,  0.0000,  0.0000, -0.4879,\n",
       "          0.0000,  0.0000, -2.4631, -1.9941, -1.7760, -0.6077, -5.3728, -1.6593,\n",
       "          0.4330]], grad_fn=<MulBackward0>)"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "z = selfembdrop(z)\n",
    "z"
   ]
  },
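  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notice that the entries that survive dropout are not just copied: in training mode <tt>nn.Dropout(p)</tt> scales them by <tt>1/(1-p)</tt>, so with <tt>p=0.4</tt> the value 0.3044 above becomes about 0.5073. A small sketch on a tensor of ones makes the scaling visible (every output is either 0 or 1/0.6), and shows that in evaluation mode dropout leaves its input unchanged. The names <tt>p_demo</tt>, <tt>drop</tt> and <tt>out</tt> are illustrative."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "p_demo = 0.4\n",
    "drop = nn.Dropout(p_demo)  # modules start in training mode\n",
    "out = drop(torch.ones(2, 5))\n",
    "print(out)\n",
    "\n",
    "# Every entry is either dropped (0) or scaled up by 1/(1-p)\n",
    "scaled = torch.full_like(out, 1 / (1 - p_demo))\n",
    "print(torch.all((out == 0) | torch.isclose(out, scaled)).item())\n",
    "\n",
    "drop.eval()  # in evaluation mode dropout passes the input through unchanged\n",
    "print(torch.equal(drop(torch.ones(2, 5)), torch.ones(2, 5)))"
   ]
  },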
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-danger\"><strong>This is how the categorical embeddings are passed into the layers.</strong></div>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TabularModel(nn.Module):\n",
    "\n",
    "    def __init__(self, emb_szs, n_cont, out_sz, layers, p=0.5):\n",
    "        super().__init__()\n",
    "        self.embeds = nn.ModuleList([nn.Embedding(ni, nf) for ni,nf in emb_szs])\n",
    "        self.emb_drop = nn.Dropout(p)\n",
    "        self.bn_cont = nn.BatchNorm1d(n_cont)\n",
    "        \n",
    "        layerlist = []\n",
    "        n_emb = sum((nf for ni,nf in emb_szs))\n",
    "        n_in = n_emb + n_cont\n",
    "        \n",
    "        for i in layers:\n",
    "            layerlist.append(nn.Linear(n_in,i)) \n",
    "            layerlist.append(nn.ReLU(inplace=True))\n",
    "            layerlist.append(nn.BatchNorm1d(i))\n",
    "            layerlist.append(nn.Dropout(p))\n",
    "            n_in = i\n",
    "        layerlist.append(nn.Linear(layers[-1],out_sz))\n",
    "            \n",
    "        self.layers = nn.Sequential(*layerlist)\n",
    "    \n",
    "    def forward(self, x_cat, x_cont):\n",
    "        embeddings = []\n",
    "        for i,e in enumerate(self.embeds):\n",
    "            embeddings.append(e(x_cat[:,i]))\n",
    "        x = torch.cat(embeddings, 1)\n",
    "        x = self.emb_drop(x)\n",
    "        \n",
    "        x_cont = self.bn_cont(x_cont)\n",
    "        x = torch.cat([x, x_cont], 1)\n",
    "        x = self.layers(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(33)\n",
    "model = TabularModel(emb_szs, conts.shape[1], 2, [200,100], p=0.4) # out_sz = 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TabularModel(\n",
       "  (embeds): ModuleList(\n",
       "    (0): Embedding(24, 12)\n",
       "    (1): Embedding(2, 1)\n",
       "    (2): Embedding(7, 4)\n",
       "  )\n",
       "  (emb_drop): Dropout(p=0.4)\n",
       "  (bn_cont): BatchNorm1d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (layers): Sequential(\n",
       "    (0): Linear(in_features=23, out_features=200, bias=True)\n",
       "    (1): ReLU(inplace)\n",
       "    (2): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (3): Dropout(p=0.4)\n",
       "    (4): Linear(in_features=200, out_features=100, bias=True)\n",
       "    (5): ReLU(inplace)\n",
       "    (6): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (7): Dropout(p=0.4)\n",
       "    (8): Linear(in_features=100, out_features=2, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define loss function & optimizer\n",
    "For our classification we'll replace the MSE loss function with <a href='https://pytorch.org/docs/stable/nn.html#crossentropyloss'><strong><tt>torch.nn.CrossEntropyLoss()</tt></strong></a><br>\n",
    "For the optimizer, we'll continue to use <a href='https://pytorch.org/docs/stable/optim.html#torch.optim.Adam'><strong><tt>torch.optim.Adam()</tt></strong></a>"
   ]
  },
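  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because <tt>nn.CrossEntropyLoss</tt> applies log-softmax internally, the model's final <tt>Linear</tt> layer outputs raw scores (logits) and needs no softmax of its own. A quick sketch with made-up logits and labels shows that the loss matches log-softmax followed by <tt>F.nll_loss</tt>. Fixed values are used so this cell does not consume the random number generator seeded for the model above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# Made-up raw scores (logits) for 4 samples and 2 classes, and their true labels\n",
    "logits = torch.tensor([[2.0, -1.0], [0.5, 0.3], [-1.2, 0.8], [0.0, 0.0]])\n",
    "target = torch.tensor([0, 1, 1, 0])\n",
    "\n",
    "ce = nn.CrossEntropyLoss()(logits, target)\n",
    "manual = F.nll_loss(F.log_softmax(logits, dim=1), target)\n",
    "print(ce.item(), torch.allclose(ce, manual))"
   ]
  },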
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Perform train/test splits\n",
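    "At this point our batch size is the entire dataset of 120,000 records. To save time we'll use only the first 60,000, and hold out the last 12,000 of those for testing (48,000 training records and 12,000 test records). Recall that our tensors are already randomly shuffled."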
   ]
  },
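  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Slicing gives a fair split here only because the rows were shuffled earlier. If your data were not already shuffled, one option (a sketch, not part of this notebook's workflow) is to apply a single random permutation to every tensor so rows stay aligned. The toy tensors below are made up: row <tt>i</tt> starts with <tt>3i</tt> and has label <tt>i</tt>, so alignment is easy to check after shuffling."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "n = 10\n",
    "cats_toy = torch.arange(n * 3).reshape(n, 3)  # row i holds [3i, 3i+1, 3i+2]\n",
    "y_toy = torch.arange(n)                        # label i belongs to row i\n",
    "\n",
    "# A local generator leaves the global RNG (seeded above for the model) untouched\n",
    "g = torch.Generator().manual_seed(0)\n",
    "perm = torch.randperm(n, generator=g)  # one permutation shared by every tensor\n",
    "cats_toy, y_toy = cats_toy[perm], y_toy[perm]\n",
    "\n",
    "# Rows and labels still line up after shuffling\n",
    "print(perm, torch.equal(cats_toy[:, 0], 3 * y_toy))"
   ]
  },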
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 60000\n",
    "test_size = 12000\n",
    "\n",
    "cat_train = cats[:batch_size-test_size]\n",
    "cat_test = cats[batch_size-test_size:batch_size]\n",
    "con_train = conts[:batch_size-test_size]\n",
    "con_test = conts[batch_size-test_size:batch_size]\n",
    "y_train = y[:batch_size-test_size]\n",
    "y_test = y[batch_size-test_size:batch_size]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "48000"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(cat_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "12000"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(cat_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train the model\n",
    "Depending on your hardware this can take a while (the run below took about 12 minutes). We've added code to report the duration at the end."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch:   1  loss: 0.73441482\n",
      "epoch:  26  loss: 0.45090991\n",
      "epoch:  51  loss: 0.35915938\n",
      "epoch:  76  loss: 0.31940848\n",
      "epoch: 101  loss: 0.29913244\n",
      "epoch: 126  loss: 0.28824982\n",
      "epoch: 151  loss: 0.28091952\n",
      "epoch: 176  loss: 0.27713534\n",
      "epoch: 201  loss: 0.27236161\n",
      "epoch: 226  loss: 0.27171907\n",
      "epoch: 251  loss: 0.26830241\n",
      "epoch: 276  loss: 0.26365638\n",
      "epoch: 300  loss: 0.25949642\n",
      "\n",
      "Duration: 709 seconds\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "start_time = time.time()\n",
    "\n",
    "epochs = 300\n",
    "losses = []\n",
    "\n",
    "for i in range(1, epochs+1):\n",
    "    y_pred = model(cat_train, con_train)\n",
    "    loss = criterion(y_pred, y_train)\n",
    "    losses.append(loss.item())  # store a plain float, not a tensor that keeps the graph alive\n",
    "    \n",
    "    # a neat trick to save screen space:\n",
    "    if i%25 == 1:\n",
    "        print(f'epoch: {i:3}  loss: {loss.item():10.8f}')\n",
    "\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "print(f'epoch: {i:3}  loss: {loss.item():10.8f}') # print the last line\n",
    "print(f'\\nDuration: {time.time() - start_time:.0f} seconds') # print the time elapsed"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot the loss function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAEKCAYAAAD9xUlFAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzt3Xt8HXWd//HX5+R+T9qkaZo0Ta/Q0pbS1gIFERWlooIuq6LiKusu7i7ouu66i+t6Q3e9rausP0RZFxdFRVBZKsKiXOR+aQu0lF7TG01bmqRpm1tzOcnn98eZxpAm6WnpyZyT834+HueRmTlzJp9h6Hln5vud75i7IyIiAhAJuwAREUkeCgURERmgUBARkQEKBRERGaBQEBGRAQoFEREZoFAQEZEBCgURERmgUBARkQGZYRdwosrLy72uri7sMkREUsqaNWua3b3ieOulXCjU1dWxevXqsMsQEUkpZrYrnvV0+UhERAYoFEREZIBCQUREBigURERkgEJBREQGKBRERGSAQkFERAakTSis2tnC1/9vE3r8qIjIyNImFNY1HOamP2zjUGdv2KWIiCSttAmFqpJcAPYd7gq5EhGR5JU2oVBZHAuF/a0KBRGRkaRNKEwOzhReUSiIiIwobUJhUlEOZvCKLh+JiIwobUIhKyNCeWGOQkFEZBRpEwoAk4tzdflIRGQU6RUKJblqaBYRGUV6hYLOFERERpVeoVCSy6HOXo709IVdiohIUkqrUJhSGuuWuvfwkZArERFJTgkNBTNbYWabzazezK4b5v1vm9kLwWuLmR1KZD01ZfkANBxUKIiIDCczURs2swzgRuAtQAOwysxWuvuGo+u4+98NWv/jwFmJqgegpiwPgIaDnYn8NSIiKSuRZwrLgHp33+7uPcDtwGWjrP9+4OcJrIdJRblkRow9OlMQERlWIkOhGtg9aL4hWHYMM5sGTAceGuH9q81stZmtbmpqOumCMiLGlNI8XT4SERlBIkPBhlk20sMMrgB+6e7Ddgty95vdfam7L62oqHhNRVWX5unykYjICBIZCg3A1EHzNcDeEda9ggRfOhoookxnCiIiI0lkKKwCZpvZdDPLJvbFv3LoSmZ2GlAGPJXAWgbUlOXT2NZNV6/uVRARGSphoeDuUeBa4H5gI3CHu79kZteb2aWDVn0/cLuP0XMyp05QDyQRkZEkrEsqgLvfC9w7ZNnnh8x/MZE1DDW9vACAHc2dzJpUNJa/WkQk6aXVHc0wOBTaQ65ERCT5pF0olOZnU5afxY5mXT4SERkq7UIBYmcLOlMQETlWWoZCXXkBO3WmICJyjLQMhRnlBbzS2kVnTzTsUkREkkpahkJd0NisswURkVdLy1D4Yw+kjpArERFJLmkZCnUTgzOFAwoFEZHB0jIUCnIyqSzOYXuTQkFEZLC0DAWInS3oTEFE5NXSNhRmVBSoTUFEZIi0DYW6iQW0dPRwuLM37FJERJJG2obCnMmxwfA2vtIaciUiIskjbUNh/pQSANbvORxyJSIiySNtQ6GiKIfJxbkKBRGRQdI2FADmVxezfq8uH4mIHJXmoVDCtqZ2Oro1BpKICKR5KMyrKsYdtuxvC7sUEZGkkNahcFrQA0mhICISk9ahMLUsn9ysCJtf0QN3REQgzUMhEjFmTypia6POFEREIM1DAWBOZRGbX1EoiIiAQoE5lYU0tnVzqLMn7FJEREKnUBhobFa7goiIQqEyFgqb1QNJREShMKUkl8KcTLYqFEREFApmxpzKQjU2i4igUABil5C27G/D3cMuRUQkVAoFYqFwsLOXpvbusEsREQmVQgGYW1UMwAaNmCoiaU6hQGwIbYB1DXq2goikN4UCUJSbxYyKAoWCiKS9hIaCma0ws81mVm9m142wznvNbIOZvWRmP0tkPaM5s6aUF/ccCuvXi4gkhYSFgpllADcCbwPmAe83s3lD1pkNfAY4
z93PAD6ZqHqOZ0F1Cftbu9nf2hVWCSIioTtuKJjZN8ys2MyyzOxBM2s2syvj2PYyoN7dt7t7D3A7cNmQdf4SuNHdDwK4e+OJ7sCpsnhaGQDP7mgJqwQRkdDFc6bwVndvBd4BNABzgE/H8blqYPeg+YZg2WBzgDlm9oSZPW1mK+LYbkLMn1JMUU4mT247EFYJIiKhy4xjnazg5yXAz929xczi2fZwKw29OywTmA1cCNQAj5nZfHd/1cV9M7sauBqgtrY2nt99wjIzIpw9YyJPbmtOyPZFRFJBPGcKvzGzTcBS4EEzqwDiufDeAEwdNF8D7B1mnbvdvdfddwCbiYXEq7j7ze6+1N2XVlRUxPGrT87ymRPZdaCTPYeOJOx3iIgks+OGgrtfB5wLLHX3XqCDY9sGhrMKmG1m080sG7gCWDlknf8F3ghgZuXELidtj7/8U2tJ0K7wYoN6IYlIeoqnofk9QNTd+8zsX4DbgCnH+5y7R4FrgfuBjcAd7v6SmV1vZpcGq90PHDCzDcDDwKfdPbSL+qdNLiJisGGfBscTkfQUT5vC59z9TjM7H7gY+HfgJuDs433Q3e8F7h2y7PODph34VPAKXW5WBtPLC9i0T8NdiEh6iqdNoS/4+XbgJne/G8hOXEnhmltVzMZXFAoikp7iCYU9ZvYD4L3AvWaWE+fnUtLcqmJ2txyhras37FJERMZcPF/u7yV27X9F0FV0AvHdp5CS5gUjpm5Uu4KIpKF4eh91AtuAi83sWmCSu/8u4ZWFZH51CQDr1ANJRNJQPL2P/hb4KTApeN1mZh9PdGFhqSjKYUpJLms1YqqIpKF4eh99FDjb3TsAzOzrwFPAdxNZWJgW1pTqTEFE0lI8bQrGH3sgEUzHNc5Fqjpzaim7DnRyqLMn7FJERMZUPGcKPwKeMbO7gvl3AbckrqTwLZpaCsCaXQd589zKkKsRERk78TQ0/wdwFdACHASucvdvJ7qwMJ1VW0puVoTHtmpwPBFJL/GcKeDuzwHPHZ03s5fdPTHDlSaB3KwMzp4+kUe3NoVdiojImDrZm9DGdZsCwAVzKtje1EHDwc6wSxERGTMnGwpDn4sw7pw7YyIQa1cQEUkXI14+MrORBqkzoDAx5SSPOZWF5GRGWNdwmMsWDX1gnIjI+DRam0LRKO/dcKoLSTaZGRHOmFLMi7qJTUTSyIih4O5fGstCktHCmlLuWL2bvn4nIzLum1FERMbvaKenwoLqEjp7+tje1B52KSIiY0KhMIqldbHHc/5hs7qmikh6iGdAvIyxKCQZTZtYwILqElau3Rt2KSIiYyKeM4V6M/ummc1LeDVJ6NIzp/DinsPsaO4IuxQRkYSLJxQWAluAH5rZ02Z2tZkVJ7iupPGWebGxj57adiDkSkREEi+esY/a3P2/3H058I/AF4B9Znarmc1KeIUhmzYxn+LcTF7co66pIjL+xdWmYGaXBqOk3gB8C5gB/Aa4N8H1hc7MWFBTwnqFgoikgXgGxNsKPAx8092fHLT8l2Z2QWLKSi7zq0u45fEddEf7yMlM23Z3EUkD8YTCQncftqO+u3/iFNeTlBZUl9Db52x5pZ0FNSVhlyMikjDxNDRPMrPfmFmzmTWa2d1mNiPhlSWRs2pj9yus2tkSciUiIokVTyj8DLgDmAxMAe4Efp7IopJNdWkeM8oL9HwFERn34npGs7v/xN2jwes20mDo7KEumFPB09sP0B3tO/7KIiIpKp5QeNjMrjOzOjObZmb/CPzWzCaY2YREF5gsXj+7nK7eftbs1PMVRGT8iqeh+X3Bz48NWf7nxM4Y0qJ94ZwZE8nKMB7d2szyWeVhlyMikhDHDQV3nz4WhSS7gpxMFteW8eiWJq572+lhlyMikhDx3LyWZWafMLNfBq9rzSxrLIpLNhfMqWDDvlaa2rrDLkVEJCHiaVO4CVgCfC94LQmWpZ0LZlcA8ER9c8iViIgkRjxtCq9z9zMHzT9kZmsTVVAyO2NKMRMK
snl0axPvOkvPbRaR8SeeM4U+M5t5dCa4cS2ufplmtsLMNptZvZldN8z7HzGzJjN7IXj9Rfylj71IxDh/VjmPbW3GPe165YpIGojnTOHTxLqlbgcMmAZcdbwPBQ/nuRF4C9AArDKzle6+Yciqv3D3a0+s7PC8fnY5K9fuZdMrbcytSpsRxEUkTYwaCmYWAY4As4HTiIXCJnePp6V1GVDv7tuDbd0OXAYMDYWUcsGcWLvCo1uaFAoiMu6MevnI3fuBb7l7t7uvc/e1cQYCQDWwe9B8Q7BsqMvNbF3Qs2nqcBsKHuyz2sxWNzWFO9REZXEup1UW8dhWNTaLyPgTT5vC78zscjOzE9z2cOsPvRD/G6DO3RcCDwC3Drchd7/Z3Ze6+9KKiooTLOPUe/3scp7d2cKRHg15ISLjSzyh8Clig+B1m1mrmbWZWWscn2sABv/lXwPsHbyCux8YdObxX8S6uya918+poCfazzM79IhOERlf4nkcZ5G7R9w9292Lg/l4LqavAmab2XQzywauAFYOXsHMqgbNXgpsPJHiw3L29AnkZ2fw+w37wy5FROSUiueO5gfjWTaUu0eBa4H7iX3Z3+HuL5nZ9WZ2abDaJ8zspeC+h08AHzmR4sOSm5XBG0+bxP0v7aevX11TRWT8GLH3kZnlAvlAuZmV8cc2gmJiz1U4Lne/lyHPcXb3zw+a/gzwmROsOSmsmD+Z3764jzW7DrJsetoMFisi49xoXVI/BnySWACs4Y+h0Ers/oO0duFpFUQMHq9vViiIyLgxYii4+w3ADWb2cXf/7hjWlBKKcrM4fXIxz+3S8xVEZPyIZ+js75rZcqBu8Pru/uME1pUSlkwr467n99DX72RETrTHrohI8omnofknwL8D5wOvC15LE1xXSlg8rZT27ihb9reFXYqIyCkRz9hHS4F5rhHgjrGkNtaW8ER9s4a8EJFxIZ6b19YDkxNdSCqqnZjPWbWl3Pb0LvrVNVVExoF4QqEc2GBm95vZyqOvRBeWKv78vOnsPNDJw5sbwy5FROQ1i+fy0RcTXUQqWzF/MpOLc7nliR28eW5l2OWIiLwmI54pmNnpAO7+CPC0uz9y9AXoIcWBrIwIf7Z8Gk/UH2DzK2pwFpHUNtrlo58Nmn5qyHvfS0AtKev9r6slNyvC/zy5I+xSRERek9FCwUaYHm4+rZUVZPPus2r49XN7aOnoCbscEZGTNloo+AjTw82nvavOq6M72s+dq3cff2URkSQ1WkNzjZn9J7GzgqPTBPPDPUEtrc2pLGLR1FJWrt3Lx94wM+xyREROymih8OlB06uHvDd0XoB3njmFL9+zge1N7cyoKAy7HBGREzbagHjDPhpTRvb2BVV85bcb+NVzDXz64tPDLkdE5ITFc/OaxGlySS4Xz5vMT57aRVtXb9jliIicMIXCKfY3b5xJa1eU259Vg7OIpB6Fwim2sKaUxbWl3LF6NxpDUERSTTxDZ3/DzIrNLMvMHjSzZjO7ciyKS1V/sriGrY3trN/TGnYpIiInJJ4zhbe6eyvwDqABmMOreybJEO9cOIXszAi/eq4h7FJERE5IPKGQFfy8BPi5u7cksJ5xoSQ/i7fMrWTl2r30RPvDLkdEJG7xhMJvzGwTsYftPGhmFUBXYstKfZcvqaalo4c/aEhtEUkhxw0Fd78OOBdY6u69QAdwWaILS3Wvn13BpKIcbnlCg+SJSOqIp6H5PUDU3fvM7F+A24ApCa8sxWVlRPirN8zk6e0tPLXtQNjliIjEJZ7LR59z9zYzOx+4GLgVuCmxZY0PHzi7lokF2fz4qZ1hlyIiEpd4QqEv+Pl24CZ3vxvITlxJ40duVgbvWFjFQ5saae+Ohl2OiMhxxRMKe8zsB8B7gXvNLCfOzwnwjjOn0B3t5/cbXgm7FBGR44rny/29wP3ACnc/BExA9ynEbUltGdPLC/jafZvYd/hI2OWIiIwqnt5HncA24GIzuxaY5O6/S3hl40QkYtx05WLa
u6J89q71YZcjIjKqeHof/S3wU2BS8LrNzD6e6MLGk9MnF3PNm2bx0KZG1uzSvX8ikrziuXz0UeBsd/+8u38eOAf4y8SWNf58ZHkd5YU5fPP+zRooT0SSVjyhYPyxBxLBtCWmnPErPzuTa94Yu2/hiXrdtyAiySmeUPgR8IyZfdHMvgg8Dfx3PBs3sxVmttnM6s3sulHW+1MzczNbGlfVKeoDZ9dSXZrHv927kb5+nS2ISPKJp6H5P4CrgBbgIHCVu3/neJ8zswzgRuBtwDzg/WY2b5j1ioBPAM+cWOmpJyczg3962+ls2NfKz57ZFXY5IiLHGDUUzCxiZuvd/Tl3/093v8Hdn49z28uAenff7u49wO0MP2bSl4FvkCaD7L1zYRXnzyrny/ds5JntuowkIsll1FBw935grZnVnsS2q4HBz6RsCJYNMLOzgKnufs9JbD8lmRk3fmAxU0pz+dzd69XoLCJJJZ42hSrgpeCpayuPvuL43HCN0QPfgGYWAb4N/P1xN2R2tZmtNrPVTU1Ncfzq5FaSn8XfXjSbLfvb+cPm1N8fERk/MuNY50snue0GYOqg+Rpg76D5ImA+8AczA5gMrDSzS9199eANufvNwM0AS5cuHRd/Wr9j4RT+/f4tXPuz5/js2+fxgbNP5mRMROTUGvFMwcxmmdl57v7I4Bexv/bjec7kKmC2mU03s2zgCmDgDMPdD7t7ubvXuXsdsV5NxwTCeJWVEeG2vzibRbWlfO7u9azZdTDskkRERr189B2gbZjlncF7o3L3KHAtsXGTNgJ3uPtLZna9mV16MsWON9PLC7jpyiVUleTyyV88T1tXb9gliUiaGy0U6tx93dCFwV/ydfFs3N3vdfc57j7T3f81WPZ5dz+mTcLdL0yXs4TBinOz+M77FrHn4BG+cPdLYZcjImlutFDIHeW9vFNdSDpbWjeBj79pNr9+fg+/WhPPlTkRkcQYLRRWmdkxYxyZ2UeBNYkrKT19/E2zeF1dGX9/51q+94f6sMsRkTQ1Wu+jTwJ3mdkH+WMILCX21LV3J7qwdJOZEeHHf342H//5c9zwwFY+uGwaJflZYZclImlmxDMFd9/v7suJdUndGby+5O7nurseI5YAedkZfPKiOXRH+/nfF/aEXY6IpKHj3qfg7g8DD49BLQLMry5hYU0JX71vI/WN7fzVhTOpLlUTjoiMDT1rOQl974OLeefCKdy+6mXec9OTdEf7jv8hEZFTQKGQhGrK8vnme87klo+8jr2Hu7jt6Zc1RpKIjAmFQhI7f1Y5y6ZP4Mv3bOCt336Uhzc3hl2SiIxzCoUkZmbc/KElfOnSM+hz52M/XkN943A3mYuInBoKhSRXmp/Nh5fX8YurzyUvO4NP3bGWxra0ePSEiIRAoZAiKopy+PrlC9n8Shvv+M/H2bivNeySRGQcUiikkBXzJ/O/15xHxIzLb3qS7z+yjd0tnWGXJSLjiEIhxcytKuaua5azuLaMr923iUtueIyDHT1hlyUi44RCIQVVleTxk48u446PnUtbd5Sfr3qZ3r7+sMsSkXFAoZCizIxl0ydw7oyJfOP/NrPoS7/jtqd3hV2WiKQ4hUKK++zb5/KR5XWcVVvGv/zveh7atD/skkQkhcXzjGZJYvOrS5hfXUJ3tI93fvdxPnXHWj50zjT+7qI5RCIWdnkikmJ0pjBO5GRm8L0PLmZBdQnffaheo6yKyElRKIwjsyYVcetVy1hYU8LX7tvErU/u5MaH6zWgnojETaEwzkQixr+9ewFZGRG+sPIlvnn/Zq756fM8tGm/BtUTkeOyVPuiWLp0qa9evTrsMpJeb18/L7d08uDG/Xz1vk24ww1XLOKyRdVhlyYiITCzNe6+9LjrKRTGv9auXj70w2d4uaWT0ycXs7+ti+9fuYQ5lUVhlyYiYyTeUNDlozRQnJvF1y5fyMyKQo709tHS0cPf/eIFeqK64U1EXk1dUtPE3KpifvnXywH43UuvcPVP1vC1
+zYxt6qIotwsLpo7icwM/Y0gku4UCmnorWdM5j1LarjliR0Dy963dCpf/ZMFmMXulhaR9KRQSFNfuPQMZk0q5LxZ5dyzbh/ff2Qb96zbiwNfvmw+ly+pCbtEEQmBQiFNFeZk8rE3zARgXlUxC6pLeHr7AdbtOcxnfv0iVaW5LKubwO827Of82eUU52aFXLGIjAX1PpJXOdjRw3t/8BQvt3RSOyGfrY3tvGNhFd9+3yK27m9nRkUBuVkZYZcpIidIXVLlpB1o7+Yrv93Iyy2dlOVn88DG/eRmRejq7WdeVTE3/9kSasrywy5TRE6AQkFOie5oH1+5ZyNZGRGqy/L4zgNbyMqI8IFltbzrrCnMmlREd7SP7IyIGqhFkphCQRJiW1M7/3DnWtY1HCbDjPNmTeTZHS0sqi3l2+9dxKTi3LBLFJFhKBQkoZrauvnW7zazetdB6ibm8+jWZnD4u7fM4a/eMANQ11aRZJIUoWBmK4AbgAzgh+7+tSHv/xVwDdAHtANXu/uG0bapUEhOO5o7+Pp9m/i/l16hMCeTc2dOZE5lIWX52Vx13nQy9GwHkVCFHgpmlgFsAd4CNACrgPcP/tI3s2J3bw2mLwX+xt1XjLZdhULyivb1c/09G9h7qIsHNv7xCXC5WRFqyvJZUF3CebPKufiMSorUxVVkTMUbCom8T2EZUO/u24OCbgcuAwZC4WggBAqA1LqWJa+SmRHh+svmA7By7V5K8rLo6I7y3K6D7DzQweP1zdz1/B7+4U5YOq2MT140h+UzJxKJGH39Tk+0n7xsdXcVCVMiQ6Ea2D1ovgE4e+hKZnYN8CkgG3hTAuuRMXTpmVMGpi9ZUAWAu7Nm10GeqD/Az57dxZX//QynTy7iuredzv97qJ4t+9v40mVn8O6zdDe1SFgSefnoPcDF7v4XwfyHgGXu/vER1v9AsP6Hh3nvauBqgNra2iW7du1KSM0ydrp6+7j3xX189b5NNLV1kxkxTq8qYv2eVv7y9dP5+7eeRkbEONjRw9qGw5xZU6KeTSKvQTK0KZwLfNHdLw7mPwPg7l8dYf0IcNDdS0bbrtoUxpfWrl5W72xhcnEecyoLuf6eDfz4qV2YQX5WBpGI0dYVJSNivGFOBXMqi7h8cTWz9SwIkROSDKGQSayh+c3AHmINzR9w95cGrTPb3bcG0+8EvnC8ohUK49+jW5p4evsBmtq6OXyklyvPmcZjW5t4aFMjL7d00tvnnD19AhfMqaCxtYtJxbk8ua2Zr/3JQmrK8mhq66aiKEddYkUGCT0UgiIuAb5DrEvqLe7+r2Z2PbDa3Vea2Q3ARUAvcBC4dnBoDEehkN6a27u5c3UDP31mFw0Hj5CdEaGnr5+MiDG1LI9pEwt4ZEsTUyfkceGcSWxramft7kPccMVZXDSvMuzyRUKTFKGQCAoFAejrdw60d1OQk8muA520d0e57lfr2Hv4CFeePY0tje08vf0AM8oL6Onrp6HlCIW5mXz64tMoL8yhJ9pPc3s3tRPy6enrZ8m0MsoLc8LeLZGEUShIWurvdyJDbpR75XAXX/ntBl5u6WRdw+FhP5eVYcyvLuFdi6q5c81uZpQXkpMZoaoklyO9fbzxtEksn1U+FrsgkhAKBZEhOnui/OiJnSyaWsqEgmwmFGSzs7mDSMR4YON+HtrYyNbGdsrys+iO9pOdGeHwkV4AMsyYUppHdWkeAFmZES47cwor5k+mub2b5vYeppblqYeUJC2FgsgJ6u3r567n93DujIlMKc0jYtDaFaW/3/nybzfQ0R2lvrGdnMwMuqJ9bG/qeNXnszMivGVeJVNKc3l6ewuOc8mCKp6ob2Z/azcXn1HJ5YtrmFySS3527BahaF8/PX39A/MiiaJQEEmg/n7nN+v2svdQF+WFsbOOBzc18uiWJprbu6mbWEBPtJ/tzR1UFucwe1IRT2xr5ug/t5zMCCV5WXT29NET7WfF/MlMnZDHgfYeWrt6KcvPZs2ug/zzJXMp
zc+ipiyf3764j3OmT8DMmFCQTVl+lnpYSdwUCiIhc3c6e/rIC+632NncwbM7WjjQ0cPBzh5aj/SSlREh2t/PAxsbaW7vpjQvi4xIhAMdsemDnbHLV8W5mbR2RV+1/awMoyQvm77+fhbUlLJxXyuvn1VORXEOO5s7uPKcaZwzYyKrdrRQOzGfPQePcHpVMUU5mWxvbqextZu5VcUU52VpwMI0oFAQSTHujpkR7eunrStKtN95dEsTO5o7uOv5PXz27XNp74qSkxWhpaOHxrZuDnX20B3t59kdLZxWWcTT2w9wpLePkiBQ8rMz6OzpG/gdxbmZuENbdyxgsjMi9LuzuLaMypJc1uxsoafPufiMSi6aW8nkklz+4/dbyMvKoLosj5dbOplYkB0MbDiZtq5eImbkZWVw6Egvew8dISczQnlhDnsOHWFeVfGrGv77+52uaJ8ul4VAoSCSptydI719/OiJnexv7WLZ9AnsOXiEySW5/H7DfkrysjirtozK4hwe2dxEnzu/WbuPjAi8rm4CGRHjN2v30h98NZjFGtr73amdkE9zew/t3VEqi3PY39pNdmaEwpxMWjp6BtYvzMmkrStKVUkuF8yuYFFtKVv2t3H3C3s5fKSXhTUl9Pc7h4708u6zqlk2fQKrdx6ksjiHmRWFLJpayt5DXTz38kGWz5zItqYOFtSUUJiTSUd3lKyMCLsOdDBtYgHZmZFj9l+X1Y6lUBCRuB39Hjj6ZbqtqZ2Wjh6eqG/mtMoiZlcWkhmJUFdeQG9fPzc/up3Nr7Rx2uQimtq6aT3Sy/zqEsqLcli9s4XtTR1csqCKx7Y28Xh9M21dUTIjxor5k6kuzeP53YfIjBgRMx6vbz6mnsyIEe1/9XdTTmaEqRPyqW9sH1hWlJtJaX4WGWYsn1VOUW4mtz+7m/LCbE6bXMSkolyKczPJzc7gwjmT6OnrJ2IQMaMoN5PaCfm4x4ZnzogYXb19dHRHmVCQPe6CRaEgIkkh2tfP/rZuCrIzKM3PPub97U3trGs4zIWnVdDWFWX9nsM89/JBqkryOH1yEY/VNzOvqpgXdh9iW1M7C6tL6HeoKctjbcPhgS/yP2xpoifaz4WnVQCwu6WTxtZu2nuijPQ1N2tSIYc6e+jTSAYSAAAITElEQVTtc5bPnDjQ5rOguoQPnTONx+qbaWzt4vxZ5bz3dVPJzcygsa2L4rwsJqXYUCoKBRFJKz3Rfvrdyc169TM53J1Dnb3cvmo3EwuyKczNxIDmjh7uWLWb0vwsKotzeWxrEzVl+Vw0t5IfPradAx09FOdmMr28gLXD3PQ4pSSX82aV09nbR+uRXhZNLaU72s+W/W28+6xqJhbksOdQJ5csqKKju48dzR0smz4htEZ9hYKIyEnq6I6yv7WLKaV55GZlsHV/G49ubSba109VaR4t7d08trWZF/ccJicrQlFOFpv3t+HuVBTF2lqGM6O8gAvmVDCxIJsXdh9izuQifrtuH9e8cSbvWDiFgpxYA3x7d5SC7Axaj0TpjvYxqTiXrt6+YwLvRCgURETGUEd3lM6ePsrys1i35zAt7T2U5mfx6JYmivOyKMvP5pdrGnh+90G6evspy4/1ECsvzKG5vRszWFJbRk5WhCfqDzC9vICdBzpwh48sr+P3G/Zz3dtO552DHmB1IpLhcZwiImmjICdz4C/9xbVlA8uX1k0YmL58SQ29ff20dPQwoSCb9XsOs6C6hIc3N7F29yGe2NbMvsNdfPjcaWzY18qK+ZOpb2znf57cSVl+FrMrCxO+HzpTEBFJYp09Ub7zwFbefVY1c6uKT3o7OlMQERkH8rMz+edL5o7Z74scfxUREUkXCgURERmgUBARkQEKBRERGaBQEBGRAQoFEREZoFAQEZEBCgURERmQcnc0m1kTsOskP14OHDt4e2rSviQn7Uty0r7ANHevON5KKRcKr4WZrY7nNu9UoH1JTtqX5KR9iZ8uH4mIyACFgoiIDEi3ULg57AJO
Ie1LctK+JCftS5zSqk1BRERGl25nCiIiMoq0CQUzW2Fmm82s3syuC7ueE2VmO83sRTN7wcxWB8smmNnvzWxr8LPseNsJg5ndYmaNZrZ+0LJha7eY/wyO0zozWxxe5ccaYV++aGZ7gmPzgpldMui9zwT7stnMLg6n6mOZ2VQze9jMNprZS2b2t8HylDsuo+xLKh6XXDN71szWBvvypWD5dDN7JjguvzCz7GB5TjBfH7xf95qLcPdx/wIygG3ADCAbWAvMC7uuE9yHnUD5kGXfAK4Lpq8Dvh52nSPUfgGwGFh/vNqBS4D7AAPOAZ4Ju/449uWLwD8Ms+684P+1HGB68P9gRtj7ENRWBSwOpouALUG9KXdcRtmXVDwuBhQG01nAM8F/7zuAK4Ll3wf+Opj+G+D7wfQVwC9eaw3pcqawDKh39+3u3gPcDlwWck2nwmXArcH0rcC7QqxlRO7+KNAyZPFItV8G/NhjngZKzaxqbCo9vhH2ZSSXAbe7e7e77wDqif2/GDp33+fuzwXTbcBGoJoUPC6j7MtIkvm4uLu3B7NZwcuBNwG/DJYPPS5Hj9cvgTebmb2WGtIlFKqB3YPmGxj9f5pk5MDvzGyNmV0dLKt0930Q+4cBTAqtuhM3Uu2peqyuDS6r3DLoMl5K7EtwyeEsYn+VpvRxGbIvkILHxcwyzOwFoBH4PbEzmUPuHg1WGVzvwL4E7x8GJr6W358uoTBccqZat6vz3H0x8DbgGjO7IOyCEiQVj9VNwExgEbAP+FawPOn3xcwKgV8Bn3T31tFWHWZZsu9LSh4Xd+9z90VADbEzmOEe0Hy03lO+L+kSCg3A1EHzNcDekGo5Ke6+N/jZCNxF7H+W/UdP4YOfjeFVeMJGqj3ljpW77w/+IfcD/8UfL0Uk9b6YWRaxL9Gfuvuvg8UpeVyG25dUPS5Hufsh4A/E2hRKzSwzeGtwvQP7ErxfQvyXN4eVLqGwCpgdtOBnE2uQWRlyTXEzswIzKzo6DbwVWE9sHz4crPZh4O5wKjwpI9W+EvizoLfLOcDho5czktWQa+vvJnZsILYvVwQ9RKYDs4Fnx7q+4QTXnf8b2Oju/zHorZQ7LiPtS4oelwozKw2m84CLiLWRPAz8abDa0ONy9Hj9KfCQB63OJy3s1vaxehHrPbGF2PW5z4ZdzwnWPoNYb4m1wEtH6yd27fBBYGvwc0LYtY5Q/8+Jnb73EvvL5qMj1U7sdPjG4Di9CCwNu/449uUnQa3rgn+kVYPW/2ywL5uBt4Vd/6C6zid2mWEd8ELwuiQVj8so+5KKx2Uh8HxQ83rg88HyGcSCqx64E8gJlucG8/XB+zNeaw26o1lERAaky+UjERGJg0JBREQGKBRERGSAQkFERAYoFEREZIBCQWQMmdmFZnZP2HWIjEShICIiAxQKIsMwsyuDce1fMLMfBIOUtZvZt8zsOTN70MwqgnUXmdnTwcBrdw16BsEsM3sgGBv/OTObGWy+0Mx+aWabzOynr3VUS5FTSaEgMoSZzQXeR2wQwkVAH/BBoAB4zmMDEz4CfCH4yI+Bf3L3hcTuoD26/KfAje5+JrCc2J3QEBvF85PExvWfAZyX8J0SiVPm8VcRSTtvBpYAq4I/4vOIDQzXD/wiWOc24NdmVgKUuvsjwfJbgTuDsaqq3f0uAHfvAgi296y7NwTzLwB1wOOJ3y2R41MoiBzLgFvd/TOvWmj2uSHrjTZGzGiXhLoHTfehf4eSRHT5SORYDwJ/amaTYOC5xdOI/Xs5OlLlB4DH3f0wcNDMXh8s/xDwiMfG828ws3cF28gxs/wx3QuRk6C/UESGcPcNZvYvxJ50FyE2Iuo1QAdwhpmtIfaEq/cFH/kw8P3gS387cFWw/EPAD8zs+mAb7xnD3RA5KRolVSROZtbu7oVh1yGSSLp8JCIiA3SmICIiA3SmICIiAxQKIiIyQKEgIiIDFAoiIjJAoSAiIgMUCiIiMuD/Aw3Fnn31yN6qAAAAAElF
TkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.plot(range(epochs), losses)\n",
    "plt.ylabel('Cross Entropy Loss')\n",
    "plt.xlabel('epoch');"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Validate the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CE Loss: 0.25455481\n"
     ]
    }
   ],
   "source": [
    "# TO EVALUATE THE ENTIRE TEST SET\n",
    "model.eval()  # disable dropout and use BatchNorm running statistics\n",
    "with torch.no_grad():\n",
    "    y_val = model(cat_test, con_test)\n",
    "    loss = criterion(y_val, y_test)\n",
    "print(f'CE Loss: {loss:.8f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
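    "Now let's compare the first 50 predictions with the true labels. The predicted class is the index of the larger output value (the <tt>argmax</tt>)."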
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MODEL OUTPUT               ARGMAX  Y_TEST\n",
      "tensor([ 1.8140, -1.6443])    0      0   \n",
      "tensor([-1.8268,  2.6373])    1      0   \n",
      "tensor([ 1.4028, -1.9248])    0      0   \n",
      "tensor([-1.9130,  1.4853])    1      1   \n",
      "tensor([ 1.1757, -2.4964])    0      0   \n",
      "tensor([ 2.0996, -2.2990])    0      0   \n",
      "tensor([ 1.3226, -1.8349])    0      0   \n",
      "tensor([-1.6211,  2.3889])    1      1   \n",
      "tensor([ 2.2489, -2.4253])    0      0   \n",
      "tensor([-0.4459,  1.1358])    1      1   \n",
      "tensor([ 1.5145, -2.1619])    0      0   \n",
      "tensor([ 0.7704, -1.9443])    0      0   \n",
      "tensor([ 0.9637, -1.3796])    0      0   \n",
      "tensor([-1.3527,  1.7322])    1      1   \n",
      "tensor([ 1.4110, -2.4595])    0      0   \n",
      "tensor([-1.4455,  2.6081])    1      1   \n",
      "tensor([ 2.2798, -2.5864])    0      1   \n",
      "tensor([ 1.4585, -2.7982])    0      0   \n",
      "tensor([ 0.3342, -0.8995])    0      0   \n",
      "tensor([ 2.0525, -1.9737])    0      0   \n",
      "tensor([-1.3571,  2.1911])    1      1   \n",
      "tensor([-0.4669,  0.2872])    1      1   \n",
      "tensor([-2.0624,  2.2875])    1      1   \n",
      "tensor([-2.1334,  2.6416])    1      1   \n",
      "tensor([-3.1325,  5.1561])    1      1   \n",
      "tensor([ 2.2128, -2.5172])    0      0   \n",
      "tensor([ 1.0346, -1.7764])    0      0   \n",
      "tensor([ 1.1221, -1.6717])    0      0   \n",
      "tensor([-2.1322,  1.6714])    1      1   \n",
      "tensor([ 1.5009, -1.6338])    0      0   \n",
      "tensor([ 2.0387, -1.8475])    0      0   \n",
      "tensor([-1.6346,  2.8899])    1      1   \n",
      "tensor([-3.0129,  2.3519])    1      1   \n",
      "tensor([-1.5746,  2.0000])    1      1   \n",
      "tensor([ 1.3056, -2.2630])    0      0   \n",
      "tensor([ 0.6631, -1.4797])    0      0   \n",
      "tensor([-1.4585,  2.1836])    1      1   \n",
      "tensor([ 1.0574, -1.5848])    0      1   \n",
      "tensor([ 0.3376, -0.8050])    0      1   \n",
      "tensor([ 1.9217, -1.9764])    0      0   \n",
      "tensor([ 0.1011, -0.5529])    0      0   \n",
      "tensor([ 0.6703, -0.5540])    0      0   \n",
      "tensor([-0.6733,  0.8777])    1      1   \n",
      "tensor([ 2.2017, -2.0445])    0      0   \n",
      "tensor([-0.0442, -0.4276])    0      0   \n",
      "tensor([-1.1204,  1.2558])    1      1   \n",
      "tensor([-1.8170,  2.7124])    1      1   \n",
      "tensor([ 1.7404, -2.0341])    0      0   \n",
      "tensor([ 1.3266, -2.3039])    0      0   \n",
      "tensor([-0.0671,  0.3291])    1      0   \n",
      "\n",
      "45 out of 50 = 90.00% correct\n"
     ]
    }
   ],
   "source": [
    "rows = 50\n",
    "correct = 0\n",
    "print(f'{\"MODEL OUTPUT\":26} ARGMAX  Y_TEST')\n",
    "for i in range(rows):\n",
    "    print(f'{str(y_val[i]):26} {y_val[i].argmax():^7}{y_test[i]:^7}')\n",
    "    if y_val[i].argmax().item() == y_test[i]:\n",
    "        correct += 1\n",
    "print(f'\\n{correct} out of {rows} = {100*correct/rows:.2f}% correct')"
   ]
  },
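  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Fifty rows make a quick sanity check, but accuracy on the whole test set is more informative. Below is a minimal sketch of a vectorized check. The helper <tt>summarize_predictions</tt> is a hypothetical name of our own. It takes the <tt>argmax</tt> of every row of logits, reports the fraction that match the labels, and builds a 2x2 confusion matrix with true classes as rows and predicted classes as columns. The demo tensors are made-up values. In this notebook you would call <tt>summarize_predictions(y_val, y_test)</tt>."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def summarize_predictions(logits, targets):\n",
    "    # Hypothetical helper: logits has shape (N, 2), targets has shape (N,)\n",
    "    preds = logits.argmax(dim=1)\n",
    "    accuracy = (preds == targets).float().mean().item()\n",
    "    # confusion[true_class, predicted_class] counts each outcome\n",
    "    confusion = torch.zeros(2, 2, dtype=torch.int64)\n",
    "    for t, p in zip(targets.tolist(), preds.tolist()):\n",
    "        confusion[t, p] += 1\n",
    "    return accuracy, confusion\n",
    "\n",
    "# Demo on made-up values; in this notebook use summarize_predictions(y_val, y_test)\n",
    "demo_logits = torch.tensor([[1.8, -1.6], [-1.8, 2.6], [0.3, -0.8], [-0.5, 1.1]])\n",
    "demo_targets = torch.tensor([0, 0, 1, 1])\n",
    "acc, cm = summarize_predictions(demo_logits, demo_targets)\n",
    "print(f'Accuracy: {acc:.2%}')\n",
    "print(cm)"
   ]
  },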
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Save the model\n",
    "Save the trained model to a file in case you want to come back later and feed new data through it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make sure to save the model only after the training has happened!\n",
    "if len(losses) == epochs:\n",
    "    torch.save(model.state_dict(), 'TaxiFareClssModel.pt')\n",
    "else:\n",
    "    print('Model has not been trained. Consider loading a trained model instead.')"
   ]
  },
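  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<tt>state_dict()</tt> stores only weights, biases and BatchNorm running statistics. For this reason, a model built with identical constructor arguments must already exist before the file can be loaded. One common pattern, sketched here under our own assumptions, saves the constructor arguments together with the <tt>state_dict</tt> in a single dictionary. The toy <tt>nn.Sequential</tt> model, the helper <tt>build_model</tt> and the dictionary keys are illustrative. An in-memory <tt>io.BytesIO</tt> buffer stands in for a file path."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import io\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "# Illustrative toy model standing in for TabularModel\n",
    "def build_model(n_in, n_hidden, n_out):\n",
    "    return nn.Sequential(nn.Linear(n_in, n_hidden), nn.ReLU(), nn.Linear(n_hidden, n_out))\n",
    "\n",
    "config = {'n_in': 6, 'n_hidden': 10, 'n_out': 2}\n",
    "model_a = build_model(**config)\n",
    "\n",
    "# Save the constructor arguments together with the learned parameters\n",
    "buffer = io.BytesIO()  # stands in for a .pt file path\n",
    "torch.save({'config': config, 'state_dict': model_a.state_dict()}, buffer)\n",
    "\n",
    "# Later: rebuild the model from the stored config, then load its weights\n",
    "buffer.seek(0)\n",
    "checkpoint = torch.load(buffer)\n",
    "model_b = build_model(**checkpoint['config'])\n",
    "model_b.load_state_dict(checkpoint['state_dict'])\n",
    "model_b.eval()\n",
    "\n",
    "# Both models should now produce identical outputs\n",
    "x = torch.randn(3, 6)\n",
    "with torch.no_grad():\n",
    "    same = torch.equal(model_a(x), model_b(x))\n",
    "print(same)"
   ]
  },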
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading a saved model (starting from scratch)\n",
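    "We can load the trained weights and biases from a saved model. The saved file contains only the model's parameters (its <tt>state_dict</tt>) and not the class definition. A freshly opened notebook therefore needs the standard imports and the <tt>TabularModel</tt> and <tt>haversine_distance</tt> definitions first. To demonstrate, restart the kernel before proceeding."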
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
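    "def haversine_distance(df, lat1, long1, lat2, long2):\n",
    "    # Great-circle distance (km) between two lat/long points stored in df columns\n",
    "    r = 6371  # average radius of Earth in kilometers\n",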
    "    phi1 = np.radians(df[lat1])\n",
    "    phi2 = np.radians(df[lat2])\n",
    "    delta_phi = np.radians(df[lat2]-df[lat1])\n",
    "    delta_lambda = np.radians(df[long2]-df[long1])\n",
    "    a = np.sin(delta_phi/2)**2 + np.cos(phi1) * np.cos(phi2) * np.sin(delta_lambda/2)**2\n",
    "    c = 2 * np.arctan2(np.sqrt(a), np.sqrt(1-a))\n",
    "    return r * c\n",
    "\n",
    "class TabularModel(nn.Module):\n",
    "    def __init__(self, emb_szs, n_cont, out_sz, layers, p=0.5):\n",
    "        super().__init__()\n",
    "        self.embeds = nn.ModuleList([nn.Embedding(ni, nf) for ni,nf in emb_szs])\n",
    "        self.emb_drop = nn.Dropout(p)\n",
    "        self.bn_cont = nn.BatchNorm1d(n_cont)\n",
    "        layerlist = []\n",
    "        n_emb = sum((nf for ni,nf in emb_szs))  # total width of the embedding vectors\n",
    "        n_in = n_emb + n_cont                  # input width of the first hidden layer\n",
    "        for i in layers:\n",
    "            layerlist.append(nn.Linear(n_in,i))\n",
    "            layerlist.append(nn.ReLU(inplace=True))\n",
    "            layerlist.append(nn.BatchNorm1d(i))\n",
    "            layerlist.append(nn.Dropout(p))\n",
    "            n_in = i\n",
    "        layerlist.append(nn.Linear(layers[-1],out_sz))  # one output per class\n",
    "        self.layers = nn.Sequential(*layerlist)\n",
    "    def forward(self, x_cat, x_cont):\n",
    "        embeddings = []\n",
    "        for i,e in enumerate(self.embeds):\n",
    "            embeddings.append(e(x_cat[:,i]))\n",
    "        x = torch.cat(embeddings, 1)\n",
    "        x = self.emb_drop(x)\n",
    "        x_cont = self.bn_cont(x_cont)\n",
    "        x = torch.cat([x, x_cont], 1)\n",
    "        return self.layers(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
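    "Now define the model. Before we can load the saved settings, we need to instantiate a TabularModel with exactly the same parameters we used during training. These are the embedding sizes, the number of continuous columns, the output size, the layer sizes and the dropout probability <tt>p</tt>. If the embedding or layer sizes differ, the tensor shapes in the saved <tt>state_dict</tt> won't match and <tt>load_state_dict</tt> will raise an error."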
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
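    "# (number of categories, embedding size) for Hour, AMorPM and Weekday\n",
    "emb_szs = [(24, 12), (2, 1), (7, 4)]\n",
    "# 6 continuous columns, 2 output classes, hidden layers of 200 and 100 neurons\n",
    "model2 = TabularModel(emb_szs, 6, 2, [200,100], p=0.4)"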
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
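    "Once the model is set up, loading the saved settings is a snap. Afterward, call <tt>model2.eval()</tt>. This switches dropout off and makes the BatchNorm layers use their stored running statistics instead of batch statistics. Without it, predictions will be noisy, and BatchNorm will raise an error when given a single input row."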
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TabularModel(\n",
       "  (embeds): ModuleList(\n",
       "    (0): Embedding(24, 12)\n",
       "    (1): Embedding(2, 1)\n",
       "    (2): Embedding(7, 4)\n",
       "  )\n",
       "  (emb_drop): Dropout(p=0.4)\n",
       "  (bn_cont): BatchNorm1d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (layers): Sequential(\n",
       "    (0): Linear(in_features=23, out_features=200, bias=True)\n",
       "    (1): ReLU(inplace)\n",
       "    (2): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (3): Dropout(p=0.4)\n",
       "    (4): Linear(in_features=200, out_features=100, bias=True)\n",
       "    (5): ReLU(inplace)\n",
       "    (6): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (7): Dropout(p=0.4)\n",
       "    (8): Linear(in_features=100, out_features=2, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model2.load_state_dict(torch.load('TaxiFareClssModel.pt'));\n",
    "model2.eval() # be sure to run this step!"
   ]
  },
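  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model outputs raw scores (logits) rather than probabilities, and <tt>argmax</tt> only tells us which score is larger. If a confidence value would be useful, <tt>torch.softmax</tt> turns each row of logits into probabilities that sum to 1. This sketch uses the first row of model output from the 50-row comparison table above. With <tt>model2</tt>, you would apply the same call to its output."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# First MODEL OUTPUT row from the comparison table above (class 1 = fare above $10)\n",
    "logits = torch.tensor([[1.8140, -1.6443]])\n",
    "\n",
    "probs = torch.softmax(logits, dim=1)  # each row now sums to 1\n",
    "pred_class = probs.argmax(dim=1).item()\n",
    "print(probs)\n",
    "print(f'Predicted class {pred_class} with probability {probs[0, pred_class].item():.3f}')"
   ]
  },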
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next we'll define a function that prompts the user for a new ride's details, applies the same preprocessing steps we used on the training data, and passes the new data through our trained model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_data(mdl): # pass in the trained model (e.g. model2)\n",
    "    # INPUT NEW DATA\n",
    "    plat = float(input('What is the pickup latitude?  '))\n",
    "    plong = float(input('What is the pickup longitude? '))\n",
    "    dlat = float(input('What is the dropoff latitude?  '))\n",
    "    dlong = float(input('What is the dropoff longitude? '))\n",
    "    psngr = int(input('How many passengers? '))\n",
    "    dt = input('What is the pickup date and time?\\nFormat as YYYY-MM-DD HH:MM:SS     ')\n",
    "    \n",
    "    # PREPROCESS THE DATA\n",
    "    dfx_dict = {'pickup_latitude':plat,'pickup_longitude':plong,'dropoff_latitude':dlat,\n",
    "         'dropoff_longitude':dlong,'passenger_count':psngr,'EDTdate':dt}\n",
    "    dfx = pd.DataFrame(dfx_dict, index=[0])\n",
    "    dfx['dist_km'] = haversine_distance(dfx,'pickup_latitude', 'pickup_longitude',\n",
    "                                        'dropoff_latitude', 'dropoff_longitude')\n",
    "    dfx['EDTdate'] = pd.to_datetime(dfx['EDTdate'])\n",
    "    \n",
    "    # We can skip the .astype('category') step since our fields are small,\n",
    "    # and encode them as integer codes right away\n",
    "    dfx['Hour'] = dfx['EDTdate'].dt.hour\n",
    "    dfx['AMorPM'] = np.where(dfx['Hour']<12,0,1)\n",
    "    dfx['Weekday'] = dfx['EDTdate'].dt.strftime(\"%a\")\n",
    "    # Map weekday names to the same alphabetical codes used for the training categories\n",
    "    dfx['Weekday'] = dfx['Weekday'].replace(['Fri','Mon','Sat','Sun','Thu','Tue','Wed'],\n",
    "                                            [0,1,2,3,4,5,6]).astype('int64')\n",
    "    # CREATE CAT AND CONT TENSORS\n",
    "    cat_cols = ['Hour', 'AMorPM', 'Weekday']\n",
    "    cont_cols = ['pickup_latitude', 'pickup_longitude', 'dropoff_latitude',\n",
    "                 'dropoff_longitude', 'passenger_count', 'dist_km']\n",
    "    xcats = np.stack([dfx[col].values for col in cat_cols], 1)\n",
    "    xcats = torch.tensor(xcats, dtype=torch.int64)\n",
    "    xconts = np.stack([dfx[col].values for col in cont_cols], 1)\n",
    "    xconts = torch.tensor(xconts, dtype=torch.float)\n",
    "    \n",
    "    # PASS NEW DATA THROUGH THE MODEL WITHOUT PERFORMING A BACKPROP\n",
    "    with torch.no_grad():\n",
    "        z = mdl(xcats, xconts).argmax().item()\n",
    "    print(f'\\nThe predicted fare class is {z}')"
   ]
  },
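  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since <tt>test_data</tt> reads its inputs with <tt>input()</tt>, it is awkward to reuse in a loop or script. Below is a minimal non-interactive sketch, where the function name <tt>make_input_tensors</tt> is hypothetical. It follows the same preprocessing as <tt>test_data</tt>, including the alphabetical weekday codes. It also computes the haversine distance inline so the cell runs on its own. Passing the returned tensors to <tt>model2</tt> inside <tt>torch.no_grad()</tt> should reproduce what <tt>test_data</tt> predicts for the same inputs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "\n",
    "# Hypothetical helper: the same preprocessing as test_data(), with arguments instead of input()\n",
    "def make_input_tensors(plat, plong, dlat, dlong, psngr, dt):\n",
    "    ts = pd.to_datetime(dt)\n",
    "    # Haversine distance in km (same formula as haversine_distance)\n",
    "    phi1, phi2 = np.radians(plat), np.radians(dlat)\n",
    "    delta_phi = np.radians(dlat - plat)\n",
    "    delta_lambda = np.radians(dlong - plong)\n",
    "    a = np.sin(delta_phi/2)**2 + np.cos(phi1) * np.cos(phi2) * np.sin(delta_lambda/2)**2\n",
    "    dist_km = 6371 * 2 * np.arctan2(np.sqrt(a), np.sqrt(1 - a))\n",
    "    # Categorical codes in the order Hour, AMorPM, Weekday (Fri=0 ... Wed=6)\n",
    "    weekdays = ['Fri', 'Mon', 'Sat', 'Sun', 'Thu', 'Tue', 'Wed']\n",
    "    hour = ts.hour\n",
    "    cats = [hour, 0 if hour < 12 else 1, weekdays.index(ts.strftime('%a'))]\n",
    "    conts = [plat, plong, dlat, dlong, psngr, float(dist_km)]\n",
    "    xcats = torch.tensor([cats], dtype=torch.int64)\n",
    "    xconts = torch.tensor([conts], dtype=torch.float)\n",
    "    return xcats, xconts\n",
    "\n",
    "xcats, xconts = make_input_tensors(40.5, -73.9, 40.52, -73.92, 2, '2010-04-15 16:00:00')\n",
    "print(xcats)  # Hour, AMorPM, Weekday\n",
    "print(xconts)"
   ]
  },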
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Feed new data through the trained model\n",
    "For convenience, here are the minimum and maximum values of each variable in the dataset:\n",
    "<table style=\"display: inline-block\">\n",
    "<tr><th>Column</th><th>Minimum</th><th>Maximum</th></tr>\n",
    "<tr><td>pickup_latitude</td><td>40</td><td>41</td></tr>\n",
    "<tr><td>pickup_longitude</td><td>-74.5</td><td>-73.3</td></tr>\n",
    "<tr><td>dropoff_latitude</td><td>40</td><td>41</td></tr>\n",
    "<tr><td>dropoff_longitude</td><td>-74.5</td><td>-73.3</td></tr>\n",
    "<tr><td>passenger_count</td><td>1</td><td>5</td></tr>\n",
    "<tr><td>EDTdate</td><td>2010-04-11 00:00:00</td><td>2010-04-24 23:59:42</td></tr>\n",
    "</table>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<strong>Use caution!</strong> One degree of latitude (from 40 to 41) spans about 111 km (69 mi). At this latitude, one degree of longitude (from -73 to -74) spans about 85 km (53 mi). The longest cab ride in the dataset covered a difference of only 0.243 degrees of latitude and 0.284 degrees of longitude. The mean difference for both latitude and longitude was about 0.02 degrees. To get a realistic prediction, use pickup and dropoff coordinates that are close to one another."
   ]
  },
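  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The distances quoted above can be checked with the same haversine formula that <tt>haversine_distance</tt> uses. This sketch is a standalone scalar version (<tt>haversine_km</tt> is our own name). It measures one degree of latitude and one degree of longitude at roughly New York City's latitude of 40.5 degrees."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def haversine_km(lat1, long1, lat2, long2, r=6371):\n",
    "    # Great-circle distance in km between two points given in degrees\n",
    "    phi1, phi2 = np.radians(lat1), np.radians(lat2)\n",
    "    delta_phi = np.radians(lat2 - lat1)\n",
    "    delta_lambda = np.radians(long2 - long1)\n",
    "    a = np.sin(delta_phi/2)**2 + np.cos(phi1) * np.cos(phi2) * np.sin(delta_lambda/2)**2\n",
    "    return r * 2 * np.arctan2(np.sqrt(a), np.sqrt(1 - a))\n",
    "\n",
    "one_deg_lat = haversine_km(40, -73.5, 41, -73.5)    # move north by 1 degree\n",
    "one_deg_long = haversine_km(40.5, -74, 40.5, -73)   # move east by 1 degree\n",
    "print(f'1 degree of latitude:  {one_deg_lat:.1f} km')\n",
    "print(f'1 degree of longitude: {one_deg_long:.1f} km')"
   ]
  },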
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "What is the pickup latitude?  40.5\n",
      "What is the pickup longitude? -73.9\n",
      "What is the dropoff latitude?  40.52\n",
      "What is the dropoff longitude? -73.92\n",
      "How many passengers? 2\n",
      "What is the pickup date and time?\n",
      "Format as YYYY-MM-DD HH:MM:SS     2010-04-15 16:00:00\n",
      "\n",
      "The predicted fare class is 1\n"
     ]
    }
   ],
   "source": [
    "test_data(model2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Perfect! Where our regression predicted a fare value of ~\\\\$14, our binary classification predicts a fare greater than $10.\n",
    "## Great job!"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
